"""
LRU缓存实现模块

提供线程安全的LRU缓存机制，用于缓存Embedding向量和查询结果。
支持TTL（生存时间）、容量限制、统计信息等功能。
"""

import time
from typing import Any, Dict, Optional, Tuple, List
from collections import OrderedDict
from threading import Lock
from dataclasses import dataclass
import hashlib
import json


@dataclass
class CacheStats:
    """缓存统计信息"""
    hits: int = 0
    misses: int = 0
    evictions: int = 0
    expirations: int = 0
    total_requests: int = 0
    
    @property
    def hit_rate(self) -> float:
        """计算缓存命中率"""
        if self.total_requests == 0:
            return 0.0
        return self.hits / self.total_requests


class CacheEntry:
    """缓存条目"""
    
    def __init__(self, key: str, value: Any, ttl: Optional[int] = None):
        """
        初始化缓存条目
        
        Args:
            key: 缓存键
            value: 缓存值
            ttl: 生存时间（秒），None表示永不过期
        """
        self.key = key
        self.value = value
        self.created_at = time.time()
        self.ttl = ttl
        self.access_count = 0
        self.last_access = self.created_at
    
    def is_expired(self) -> bool:
        """检查缓存是否过期"""
        if self.ttl is None:
            return False
        return time.time() - self.created_at > self.ttl
    
    def touch(self):
        """更新访问时间和计数"""
        self.last_access = time.time()
        self.access_count += 1


class LRUCache:
    """
    线程安全的LRU缓存实现
    
    特性：
    - LRU（最近最少使用）淘汰策略
    - TTL支持（可选）
    - 线程安全
    - 统计信息
    - 批量操作
    """
    
    def __init__(self, max_size: int = 1000, ttl: Optional[int] = None):
        """
        初始化LRU缓存
        
        Args:
            max_size: 缓存最大容量
            ttl: 默认生存时间（秒），None表示永不过期
        """
        self.max_size = max_size
        self.default_ttl = ttl
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = Lock()
        self._stats = CacheStats()
    
    def get(self, key: str, default: Any = None) -> Any:
        """
        获取缓存值
        
        Args:
            key: 缓存键
            default: 键不存在时的默认值
            
        Returns:
            缓存值或默认值
        """
        with self._lock:
            self._stats.total_requests += 1
            
            if key not in self._cache:
                self._stats.misses += 1
                return default
            
            entry = self._cache[key]
            
            # 检查是否过期
            if entry.is_expired():
                self._cache.pop(key)
                self._stats.expirations += 1
                self._stats.misses += 1
                return default
            
            # 更新访问信息并移到末尾（最近使用）
            entry.touch()
            self._cache.move_to_end(key)
            self._stats.hits += 1
            
            return entry.value
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """
        设置缓存值
        
        Args:
            key: 缓存键
            value: 缓存值
            ttl: 生存时间（秒），None使用默认TTL
        """
        with self._lock:
            # 如果键已存在，先删除
            if key in self._cache:
                self._cache.pop(key)
            
            # 检查容量限制
            while len(self._cache) >= self.max_size:
                # 移除最旧的项（LRU）
                self._cache.popitem(last=False)
                self._stats.evictions += 1
            
            # 添加新条目
            actual_ttl = ttl if ttl is not None else self.default_ttl
            entry = CacheEntry(key, value, actual_ttl)
            self._cache[key] = entry
    
    def delete(self, key: str) -> bool:
        """
        删除缓存项
        
        Args:
            key: 缓存键
            
        Returns:
            是否成功删除
        """
        with self._lock:
            if key in self._cache:
                self._cache.pop(key)
                return True
            return False
    
    def exists(self, key: str) -> bool:
        """
        检查键是否存在且未过期
        
        Args:
            key: 缓存键
            
        Returns:
            是否存在
        """
        with self._lock:
            if key not in self._cache:
                return False
            
            entry = self._cache[key]
            if entry.is_expired():
                self._cache.pop(key)
                self._stats.expirations += 1
                return False
            
            return True
    
    def clear(self) -> None:
        """清空所有缓存"""
        with self._lock:
            self._cache.clear()
    
    def cleanup_expired(self) -> int:
        """
        清理所有过期条目
        
        Returns:
            清理的条目数量
        """
        with self._lock:
            expired_keys = [
                key for key, entry in self._cache.items()
                if entry.is_expired()
            ]
            
            for key in expired_keys:
                self._cache.pop(key)
                self._stats.expirations += 1
            
            return len(expired_keys)
    
    def get_stats(self) -> CacheStats:
        """获取缓存统计信息"""
        with self._lock:
            return CacheStats(
                hits=self._stats.hits,
                misses=self._stats.misses,
                evictions=self._stats.evictions,
                expirations=self._stats.expirations,
                total_requests=self._stats.total_requests
            )
    
    def reset_stats(self) -> None:
        """重置统计信息"""
        with self._lock:
            self._stats = CacheStats()
    
    def size(self) -> int:
        """获取当前缓存大小"""
        with self._lock:
            return len(self._cache)
    
    def get_all_keys(self) -> List[str]:
        """获取所有缓存键"""
        with self._lock:
            return list(self._cache.keys())
    
    def get_info(self) -> Dict[str, Any]:
        """
        获取缓存详细信息
        
        Returns:
            包含配置和统计的字典
        """
        stats = self.get_stats()
        with self._lock:
            return {
                "max_size": self.max_size,
                "current_size": len(self._cache),
                "default_ttl": self.default_ttl,
                "stats": {
                    "hits": stats.hits,
                    "misses": stats.misses,
                    "evictions": stats.evictions,
                    "expirations": stats.expirations,
                    "total_requests": stats.total_requests,
                    "hit_rate": f"{stats.hit_rate:.2%}"
                }
            }


class EmbeddingCache(LRUCache):
    """
    专门用于Embedding向量的缓存
    
    提供文本到向量的缓存功能，自动生成缓存键
    """
    
    @staticmethod
    def _generate_key(text: str, model: str = "") -> str:
        """
        生成缓存键
        
        Args:
            text: 输入文本
            model: 模型名称
            
        Returns:
            缓存键
        """
        content = f"{model}:{text}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()
    
    def get_embedding(self, text: str, model: str = "") -> Optional[List[float]]:
        """
        获取文本的Embedding向量
        
        Args:
            text: 输入文本
            model: 模型名称
            
        Returns:
            向量或None
        """
        key = self._generate_key(text, model)
        return self.get(key)
    
    def set_embedding(self, text: str, embedding: List[float], 
                     model: str = "", ttl: Optional[int] = None) -> None:
        """
        缓存文本的Embedding向量
        
        Args:
            text: 输入文本
            embedding: 向量
            model: 模型名称
            ttl: 生存时间
        """
        key = self._generate_key(text, model)
        self.set(key, embedding, ttl)
    
    def get_batch_embeddings(self, texts: List[str], model: str = "") -> Dict[str, Optional[List[float]]]:
        """
        批量获取Embedding向量
        
        Args:
            texts: 文本列表
            model: 模型名称
            
        Returns:
            文本到向量的映射字典，未命中的为None
        """
        result = {}
        for text in texts:
            result[text] = self.get_embedding(text, model)
        return result
    
    def set_batch_embeddings(self, embeddings: Dict[str, List[float]], 
                           model: str = "", ttl: Optional[int] = None) -> None:
        """
        批量缓存Embedding向量
        
        Args:
            embeddings: 文本到向量的映射字典
            model: 模型名称
            ttl: 生存时间
        """
        for text, embedding in embeddings.items():
            self.set_embedding(text, embedding, model, ttl)


class QueryCache(LRUCache):
    """
    查询结果缓存
    
    用于缓存检索查询的结果
    """
    
    @staticmethod
    def _generate_key(query: str, filters: Optional[Dict[str, Any]] = None, 
                     top_k: int = 10) -> str:
        """
        生成查询缓存键
        
        Args:
            query: 查询文本
            filters: 过滤条件
            top_k: 返回结果数量
            
        Returns:
            缓存键
        """
        # 将查询参数序列化为稳定的字符串
        cache_obj = {
            "query": query,
            "filters": filters or {},
            "top_k": top_k
        }
        content = json.dumps(cache_obj, sort_keys=True)
        return hashlib.md5(content.encode('utf-8')).hexdigest()
    
    def get_query_result(self, query: str, filters: Optional[Dict[str, Any]] = None,
                        top_k: int = 10) -> Optional[Any]:
        """
        获取查询结果
        
        Args:
            query: 查询文本
            filters: 过滤条件
            top_k: 返回结果数量
            
        Returns:
            查询结果或None
        """
        key = self._generate_key(query, filters, top_k)
        return self.get(key)
    
    def set_query_result(self, query: str, result: Any, 
                        filters: Optional[Dict[str, Any]] = None,
                        top_k: int = 10, ttl: Optional[int] = None) -> None:
        """
        缓存查询结果
        
        Args:
            query: 查询文本
            result: 查询结果
            filters: 过滤条件
            top_k: 返回结果数量
            ttl: 生存时间
        """
        key = self._generate_key(query, filters, top_k)
        self.set(key, result, ttl)
