"""
ChromaDB 数据访问对象（DAO）

提供 ChromaDB 向量数据库的统一访问接口。
负责向量存储、相似度检索和元数据管理。
"""

from pathlib import Path
from typing import Dict, List, Optional, Any

import chromadb
from chromadb.api import ClientAPI
from chromadb.config import Settings
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Embedding, Embeddings, Metadata, Metadatas

from src.core.logger import get_logger


class ChromaDAO:
    """
    Data access object for a ChromaDB vector store.

    Manages a local, persistent ChromaDB client plus a single collection and
    exposes document CRUD, similarity search and metadata filtering on top of
    it. Documents are keyed by question ID. Intended for a question bank of
    roughly 10k-100k entries.

    The client and the collection are opened lazily on first use; the
    constructor only makes sure the persistence directory exists.
    """

    def __init__(
        self,
        persist_dir: str,
        collection_name: str = "questions_collection",
        distance_metric: str = "cosine",
        logger=None
    ):
        """
        Initialize the DAO without opening the database yet.

        Args:
            persist_dir: Directory ChromaDB persists to; created if missing.
            collection_name: Name of the collection to use.
            distance_metric: Distance metric ("cosine", "l2" or "ip"). Used as
                ``hnsw:space`` when the collection is created, and to convert
                distances into similarity scores in ``search_similar``.
            logger: Logger instance; defaults to ``get_logger()``.
        """
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.distance_metric = distance_metric
        self.logger = logger or get_logger()

        self._ensure_persist_directory()
        self._client: Optional[ClientAPI] = None
        self._collection: Optional[Collection] = None

    def _ensure_persist_directory(self):
        """Create the persistence directory (and parents) if it does not exist."""
        persist_path = Path(self.persist_dir)
        persist_path.mkdir(parents=True, exist_ok=True)

    def connect(self) -> ClientAPI:
        """
        Return the ChromaDB client, creating it on first call.

        Returns:
            ClientAPI: A persistent client rooted at ``persist_dir`` with
            telemetry disabled and ``allow_reset`` enabled (``close`` relies
            on reset being allowed).

        Raises:
            Exception: Whatever ChromaDB raises while creating the client
                (logged and re-raised).
        """
        try:
            if self._client is None:
                self._client = chromadb.PersistentClient(
                    path=self.persist_dir,
                    settings=Settings(
                        anonymized_telemetry=False,  # disable telemetry
                        allow_reset=True
                    )
                )
                self.logger.debug(f"成功连接到 ChromaDB: {self.persist_dir}")

            return self._client

        except Exception as e:
            self.logger.error(f"ChromaDB 连接失败: {e}")
            raise

    def initialize_collection(self):
        """
        Get the configured collection, creating it if it does not exist yet.

        If an existing collection was created with a different ``hnsw:space``
        than ``self.distance_metric``, a warning is logged. ``search_similar``
        converts distances based on ``self.distance_metric``, so a mismatch
        produces misleading similarity scores.

        Raises:
            Exception: Whatever the client raises if connecting or creating
                the collection fails (logged and re-raised).
        """
        try:
            client = self.connect()

            try:
                self._collection = client.get_collection(name=self.collection_name)
                self.logger.info(f"获取到已存在的 Collection: {self.collection_name}")
            except Exception:
                # Any lookup failure is treated as "collection does not exist";
                # a genuine backend error will usually resurface from
                # create_collection below.
                self._collection = client.create_collection(
                    name=self.collection_name,
                    metadata={
                        "description": "题目向量存储",
                        "distance_metric": self.distance_metric,
                        "hnsw:space": self.distance_metric
                    }
                )
                self.logger.info(f"成功创建 Collection: {self.collection_name}")
            else:
                self._warn_on_metric_mismatch()

        except Exception as e:
            self.logger.error(f"初始化 Collection 失败: {e}")
            raise

    def _warn_on_metric_mismatch(self):
        """Log a warning when the open collection's stored metric differs from the configured one."""
        metadata = getattr(self._collection, "metadata", None) or {}
        stored_metric = metadata.get("hnsw:space")
        if stored_metric and stored_metric != self.distance_metric:
            self.logger.warning(
                f"Collection {self.collection_name} 的距离度量为 {stored_metric}，"
                f"与配置的 {self.distance_metric} 不一致，相似度换算可能不准确"
            )

    def get_collection(self) -> Collection:
        """
        Return the collection, initializing it lazily if needed.

        Returns:
            Collection: The ChromaDB collection object.

        Raises:
            RuntimeError: If initialization finished without a collection.
        """
        if self._collection is None:
            self.initialize_collection()

        if self._collection is None:
            raise RuntimeError("Collection 未初始化")

        return self._collection

    def count_documents(self) -> int:
        """
        Return the number of documents in the collection.

        Returns:
            int: Total document count.
        """
        try:
            collection = self.get_collection()
            return collection.count()
        except Exception as e:
            self.logger.error(f"获取文档数量失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # Vector storage operations
    # -------------------------------------------------------------------------

    def add_document(
        self,
        question_id: str,
        content: str,
        embedding: Embedding,
        metadata: Metadata
    ) -> bool:
        """
        Add a single document.

        Args:
            question_id: Question ID, used as the document ID.
            content: Question text.
            embedding: Vector representation of ``content``.
            metadata: Metadata such as category or difficulty.

        Returns:
            bool: Always True; failures are logged and re-raised.
        """
        try:
            collection = self.get_collection()

            collection.add(
                ids=[question_id],
                documents=[content],
                embeddings=[embedding],
                metadatas=[metadata]
            )

            self.logger.debug(f"成功添加文档: {question_id}")
            return True

        except Exception as e:
            self.logger.error(f"添加文档失败: {e}")
            raise

    def add_documents_batch(
        self,
        question_ids: List[str],
        contents: List[str],
        embeddings: Embeddings,
        metadatas: Metadatas
    ) -> bool:
        """
        Add many documents in one call.

        Args:
            question_ids: Question IDs.
            contents: Texts, aligned with ``question_ids``.
            embeddings: Vectors, aligned with ``question_ids``.
            metadatas: Metadata dicts, aligned with ``question_ids``.

        Returns:
            bool: True on success, including for an empty batch (no-op).

        Raises:
            ValueError: If the four lists differ in length.
        """
        try:
            if not (len(question_ids) == len(contents) == len(embeddings) == len(metadatas)):
                raise ValueError("批量添加时，所有列表长度必须相同")

            if not question_ids:
                # Nothing to add; ChromaDB rejects empty id lists.
                return True

            collection = self.get_collection()

            collection.add(
                ids=question_ids,
                documents=contents,
                embeddings=embeddings,
                metadatas=metadatas
            )

            self.logger.info(f"成功批量添加 {len(question_ids)} 个文档")
            return True

        except Exception as e:
            self.logger.error(f"批量添加文档失败: {e}")
            raise

    def update_document(
        self,
        question_id: str,
        content: Optional[str] = None,
        embedding: Optional[Embedding] = None,
        metadata: Optional[Metadata] = None
    ) -> bool:
        """
        Update any subset of a document's text, vector and metadata.

        Args:
            question_id: Question ID.
            content: New text, or None to keep the current one.
            embedding: New vector, or None to keep the current one.
            metadata: New metadata, or None to keep the current one.

        Returns:
            bool: Always True; failures are logged and re-raised.
        """
        try:
            collection = self.get_collection()

            # Only forward the fields the caller actually supplied.
            update_params: Dict[str, Any] = {"ids": [question_id]}

            if content is not None:
                update_params["documents"] = [content]
            if embedding is not None:
                update_params["embeddings"] = [embedding]
            if metadata is not None:
                update_params["metadatas"] = [metadata]

            collection.update(**update_params)

            self.logger.debug(f"成功更新文档: {question_id}")
            return True

        except Exception as e:
            self.logger.error(f"更新文档失败: {e}")
            raise

    def delete_document(self, question_id: str) -> bool:
        """
        Delete a single document.

        Args:
            question_id: Question ID.

        Returns:
            bool: Always True; failures are logged and re-raised.
        """
        try:
            collection = self.get_collection()

            collection.delete(ids=[question_id])

            self.logger.debug(f"成功删除文档: {question_id}")
            return True

        except Exception as e:
            self.logger.error(f"删除文档失败: {e}")
            raise

    def delete_documents_batch(self, question_ids: List[str]) -> bool:
        """
        Delete many documents in one call.

        Args:
            question_ids: Question IDs to delete.

        Returns:
            bool: True on success, including for an empty list (no-op).
        """
        if not question_ids:
            # Never hand ChromaDB an empty id list: depending on the version it
            # is either rejected or treated as "no id filter".
            return True

        try:
            collection = self.get_collection()

            collection.delete(ids=question_ids)

            self.logger.info(f"成功批量删除 {len(question_ids)} 个文档")
            return True

        except Exception as e:
            self.logger.error(f"批量删除文档失败: {e}")
            raise

    def document_exists(self, question_id: str) -> bool:
        """
        Check whether a document exists.

        Best-effort: any error is logged and reported as False.

        Args:
            question_id: Question ID.

        Returns:
            bool: True if the ID is present in the collection.
        """
        try:
            collection = self.get_collection()

            # IDs are always returned; skip fetching documents/metadata.
            result = collection.get(ids=[question_id], include=[])
            return bool(result and result.get('ids'))

        except Exception as e:
            self.logger.error(f"检查文档存在性失败: {e}")
            return False

    # -------------------------------------------------------------------------
    # Vector retrieval operations
    # -------------------------------------------------------------------------

    def _distance_to_similarity(self, distance: float) -> float:
        """
        Convert a ChromaDB distance into a "higher is more similar" score.

        ChromaDB reports cosine distance as ``1 - cosine_similarity`` and
        inner-product distance as ``1 - dot_product``, so both invert exactly
        with ``1 - distance``. Any other metric (L2) is squashed with
        ``1 / (1 + distance)``.
        """
        if self.distance_metric in ("cosine", "ip"):
            return 1.0 - distance
        return 1.0 / (1.0 + distance)

    def search_similar(
        self,
        query_embedding: Embedding,
        top_k: int = 10,
        where: Optional[Dict[str, Any]] = None,
        min_similarity: float = 0.0
    ) -> List[Dict[str, Any]]:
        """
        Return the documents closest to ``query_embedding``.

        Args:
            query_embedding: Query vector.
            top_k: Maximum number of results; values <= 0 return [].
            where: Optional metadata filter.
            min_similarity: Results scoring below this are dropped.

        Returns:
            List[Dict]: In ChromaDB's ranking order, dicts with keys
            ``question_id``, ``content``, ``metadata``, ``similarity_score``
            and the raw ``distance``.
        """
        if top_k <= 0:
            return []

        try:
            collection = self.get_collection()

            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k,
                where=where,
                include=["documents", "metadatas", "distances"]
            )

            formatted_results: List[Dict[str, Any]] = []
            if not results:
                return formatted_results

            # query() returns one inner list per query embedding; we sent one.
            ids_list = (results.get("ids") or [[]])[0]
            distances_list = (results.get("distances") or [[]])[0]
            documents_list = (results.get("documents") or [[]])[0]
            metadatas_list = (results.get("metadatas") or [[]])[0]

            for i, question_id in enumerate(ids_list):
                distance = distances_list[i]
                similarity_score = self._distance_to_similarity(distance)

                if similarity_score < min_similarity:
                    continue

                formatted_results.append({
                    "question_id": question_id,
                    "content": documents_list[i],
                    "metadata": metadatas_list[i],
                    "similarity_score": similarity_score,
                    "distance": distance
                })

            self.logger.debug(f"相似度检索完成，返回 {len(formatted_results)} 个结果")
            return formatted_results

        except Exception as e:
            self.logger.error(f"相似度检索失败: {e}")
            raise

    @staticmethod
    def _format_get_results(results: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Flatten a ``Collection.get`` result into one dict per document."""
        if not results:
            return []

        ids_list = results.get("ids") or []
        documents_list = results.get("documents") or []
        metadatas_list = results.get("metadatas") or []

        return [
            {
                "question_id": question_id,
                "content": documents_list[i],
                "metadata": metadatas_list[i]
            }
            for i, question_id in enumerate(ids_list)
        ]

    def get_by_ids(self, question_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Fetch documents by ID.

        Args:
            question_ids: Question IDs; an empty list returns [].

        Returns:
            List[Dict]: Dicts with ``question_id``, ``content``, ``metadata``
            for the IDs that exist.
        """
        if not question_ids:
            # An empty id list may be treated by ChromaDB as "no id filter",
            # which would return the whole collection.
            return []

        try:
            collection = self.get_collection()

            results = collection.get(
                ids=question_ids,
                include=["documents", "metadatas"]
            )

            return self._format_get_results(results)

        except Exception as e:
            self.logger.error(f"根据ID获取文档失败: {e}")
            raise

    def filter_by_metadata(
        self,
        where: Dict[str, Any],
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Fetch documents whose metadata matches a filter.

        Args:
            where: ChromaDB metadata filter.
            limit: Maximum number of documents to return.

        Returns:
            List[Dict]: Dicts with ``question_id``, ``content``, ``metadata``.
        """
        try:
            collection = self.get_collection()

            results = collection.get(
                where=where,
                limit=limit,
                include=["documents", "metadatas"]
            )

            return self._format_get_results(results)

        except Exception as e:
            self.logger.error(f"元数据过滤失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # Metadata update operations
    # -------------------------------------------------------------------------

    def update_metadata(self, question_id: str, metadata: Metadata) -> bool:
        """
        Update only a document's metadata (text and vector are untouched).

        Args:
            question_id: Question ID.
            metadata: New metadata.

        Returns:
            bool: Always True; failures are logged and re-raised.
        """
        try:
            collection = self.get_collection()

            collection.update(
                ids=[question_id],
                metadatas=[metadata]
            )

            self.logger.debug(f"成功更新文档元数据: {question_id}")
            return True

        except Exception as e:
            self.logger.error(f"更新文档元数据失败: {e}")
            raise

    def update_metadata_batch(
        self,
        question_ids: List[str],
        metadatas: Metadatas
    ) -> bool:
        """
        Update the metadata of many documents in one call.

        Args:
            question_ids: Question IDs.
            metadatas: New metadata, aligned with ``question_ids``.

        Returns:
            bool: True on success, including for an empty batch (no-op).

        Raises:
            ValueError: If the two lists differ in length.
        """
        try:
            if len(question_ids) != len(metadatas):
                raise ValueError("批量更新时，ID列表和元数据列表长度必须相同")

            if not question_ids:
                # Nothing to update; ChromaDB rejects empty id lists.
                return True

            collection = self.get_collection()

            collection.update(
                ids=question_ids,
                metadatas=metadatas
            )

            self.logger.info(f"成功批量更新 {len(question_ids)} 个文档的元数据")
            return True

        except Exception as e:
            self.logger.error(f"批量更新元数据失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # Collection management operations
    # -------------------------------------------------------------------------

    def reset_collection(self) -> bool:
        """
        Drop the collection (deleting all its data) and recreate it empty.

        Returns:
            bool: Always True; failures are logged and re-raised.
        """
        try:
            client = self.connect()

            try:
                client.delete_collection(name=self.collection_name)
                self.logger.warning(f"已删除 Collection: {self.collection_name}")
            except Exception as e:
                # The collection may simply not exist yet; keep going, but
                # leave a trace instead of hiding the error.
                self.logger.debug(f"删除 Collection 时出现异常（可能不存在）: {e}")

            self._collection = None
            self.initialize_collection()

            self.logger.info("Collection 重置成功")
            return True

        except Exception as e:
            self.logger.error(f"重置 Collection 失败: {e}")
            raise

    def get_all_ids(self) -> List[str]:
        """
        Return every document ID in the collection.

        Returns:
            List[str]: All question IDs (empty if the collection is empty).
        """
        try:
            collection = self.get_collection()

            # get() without filters returns every document; include=[] keeps it
            # to IDs only so documents/metadata are not loaded needlessly.
            results = collection.get(include=[])

            return results['ids'] if results and results.get('ids') else []

        except Exception as e:
            self.logger.error(f"获取所有文档ID失败: {e}")
            raise

    def close(self, reset_data: bool = True):
        """
        Drop the client and collection references and try to release file handles.

        WARNING: with the default ``reset_data=True`` this calls
        ``client.reset()``, which deletes ALL data persisted by this client,
        not just the in-memory state. Pass ``reset_data=False`` to close
        without destroying the stored vectors.

        Errors during shutdown are logged, not raised; the client reference
        is always cleared.

        Args:
            reset_data: Whether to wipe the database via ``client.reset()``
                before releasing the client.
        """
        if self._client:
            try:
                # Drop the collection reference first so nothing keeps it alive.
                self._collection = None

                if reset_data:
                    self._client.reset()
                    self.logger.info("ChromaDB 客户端已重置")

                del self._client

                # Collect several times so lingering references holding file
                # handles are released (relevant e.g. on Windows temp dirs).
                import gc
                for _ in range(3):
                    gc.collect()

                self.logger.info("ChromaDB 客户端已关闭，文件句柄已释放")
            except Exception as e:
                self.logger.error(f"关闭 ChromaDB 客户端时出错: {e}")
            finally:
                self._client = None
                self.logger.debug("ChromaDB 客户端引用已清除")
