"""
题目检索服务

提供题目检索功能，包括：
- 语义检索：基于向量相似度的智能检索
- 关键词检索：基于SQLite FTS的全文检索
- 混合检索：结合语义和关键词的混合检索模式

通过整合ChromaDB和SQLite的优势，提供灵活多样的检索体验。
"""

import sys
import os

# Put the project root (two directory levels above this file) at the front of
# sys.path so the absolute ``src.*`` imports below resolve even when this
# module is loaded from outside the project root.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from typing import Dict, List, Optional, Any, Tuple, Union
from datetime import datetime
import json

from src.core.logger import get_logger
from src.database.database_manager import DatabaseManager
from src.services.embedding_service import EmbeddingService
from src.database.models import QuestionSearchFilter


class SearchService:
    """
    题目检索服务

    提供多种检索方式：
    1. 语义检索 - 使用向量相似度查找语义上相似的题目
    2. 关键词检索 - 使用全文检索查找包含特定关键词的题目
    3. 混合检索 - 结合语义检索和关键词检索的结果
    """

    def __init__(
        self,
        db_manager: DatabaseManager,
        embedding_service: EmbeddingService,
        logger=None
    ):
        """
        初始化检索服务

        Args:
            db_manager: 数据库管理器实例
            embedding_service: Embedding服务实例
            logger: 日志记录器实例
        """
        self.db_manager = db_manager
        self.embedding_service = embedding_service
        self.logger = logger or get_logger()

    # -------------------------------------------------------------------------
    # 语义检索
    # -------------------------------------------------------------------------

    def search_by_semantic(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[QuestionSearchFilter] = None,
        min_similarity: float = 0.0,
        include_metadata: bool = True
    ) -> Dict[str, Any]:
        """
        语义检索 - 基于向量相似度的智能检索

        将查询文本转换为向量，然后在ChromaDB中查找最相似的题目。
        适用于自然语言查询，如"查找关于数据结构栈的题目"。

        Args:
            query: 检索查询文本
            top_k: 返回结果数量
            filters: 过滤条件
            min_similarity: 最低相似度阈值（0-1）
            include_metadata: 是否包含元数据

        Returns:
            Dict: 检索结果，包含题目列表和元信息

        Raises:
            ValueError: 查询参数无效
            Exception: 检索过程出错
        """
        try:
            # 验证参数
            if not query or not query.strip():
                raise ValueError("查询文本不能为空")

            if top_k < 1:
                top_k = 1
            if top_k > 100:
                top_k = 100

            if min_similarity < 0 or min_similarity > 1:
                self.logger.warning(f"无效的相似度阈值: {min_similarity}，使用默认值 0.0")
                min_similarity = 0.0

            # 转换查询为向量
            self.logger.info(f"正在生成查询向量: '{query}'")
            query_embedding = self.embedding_service.embed_text(query)

            # 准备过滤条件
            where_filter = self._build_chroma_filter(filters) if filters else None

            # 执行向量检索
            self.logger.info(f"执行语义检索，top_k={top_k}")
            raw_results = self.db_manager.chroma_dao.search_similar(
                query_embedding=query_embedding, # type: ignore
                top_k=top_k,
                where=where_filter,
                min_similarity=min_similarity
            )

            # 增强结果数据
            enriched_results = []
            for result in raw_results:
                # 从SQLite获取完整题目数据
                question_data = self.db_manager.sqlite_dao.get_question(
                    result['question_id']
                )

                if question_data:
                    # 合并数据
                    enriched_result = {
                        'question_id': result['question_id'],
                        'title': question_data.get('title', ''),
                        'content': question_data.get('content', ''),
                        'question_type': question_data.get('question_type', ''),
                        'category': question_data.get('category', ''),
                        'difficulty': question_data.get('difficulty', ''),
                        'tags': question_data.get('tags', []),
                        'answer': question_data.get('answer', ''),
                        'explanation': question_data.get('explanation', ''),
                        'status': question_data.get('status', ''),
                        'created_at': question_data.get('created_at', ''),
                        'updated_at': question_data.get('updated_at', ''),
                        'search_score': result.get('similarity_score', 0.0),
                        'search_type': 'semantic',
                        'query': query
                    }

                    # 如果不需要元数据，则移除answer和explanation
                    if not include_metadata:
                        enriched_result.pop('answer', None)
                        enriched_result.pop('explanation', None)

                    enriched_results.append(enriched_result)

            # 构建返回结果
            result = {
                'query': query,
                'search_type': 'semantic',
                'total_results': len(enriched_results),
                'results': enriched_results,
                'metadata': {
                    'top_k': top_k,
                    'min_similarity': min_similarity,
                    'has_filters': filters is not None,
                    'filters_applied': filters.model_dump() if filters else None
                },
                'timestamp': datetime.now().isoformat()
            }

            self.logger.info(
                f"语义检索完成: 查询='{query}', "
                f"返回{len(enriched_results)}个结果"
            )

            return result

        except ValueError as e:
            self.logger.error(f"语义检索参数验证失败: {e}")
            raise
        except Exception as e:
            self.logger.error(f"语义检索失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # 关键词检索
    # -------------------------------------------------------------------------

    def search_by_keyword(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[QuestionSearchFilter] = None,
        search_fields: Optional[List[str]] = None,
        match_mode: str = 'OR',
        include_metadata: bool = True
    ) -> Dict[str, Any]:
        """
        关键词检索 - 基于SQLite全文检索

        在SQLite数据库中执行全文检索，查找包含指定关键词的题目。
        适用于精确关键词查询，如"查找包含'二叉树'的题目"。

        Args:
            query: 检索查询文本（支持多个关键词）
            top_k: 返回结果数量
            filters: 过滤条件
            search_fields: 搜索字段列表，默认为['content', 'title', 'tags']
            match_mode: 匹配模式，'OR'或'AND'
            include_metadata: 是否包含答案和解析

        Returns:
            Dict: 检索结果

        Raises:
            ValueError: 查询参数无效
            Exception: 检索过程出错
        """
        try:
            # 验证参数
            if not query or not query.strip():
                raise ValueError("查询文本不能为空")

            if top_k < 1:
                top_k = 1
            if top_k > 100:
                top_k = 100

            if match_mode not in ['OR', 'AND']:
                self.logger.warning(f"无效的匹配模式: {match_mode}，使用默认值 OR")
                match_mode = 'OR'

            # 设置默认搜索字段
            if search_fields is None:
                search_fields = ['content', 'title', 'tags']

            # 执行关键词检索
            self.logger.info(
                f"执行关键词检索: 查询='{query}', "
                f"字段={search_fields}, 模式={match_mode}"
            )
            results = self.db_manager.search_questions_by_keyword(
                query=query,
                top_k=top_k,
                filters=filters,
                search_fields=search_fields,
                match_mode=match_mode
            )

            # 增强结果数据
            enriched_results = []
            for result in results:
                enriched_result = {
                    'question_id': result['question_id'],
                    'title': result.get('title', ''),
                    'content': result.get('content', ''),
                    'question_type': result.get('question_type', ''),
                    'category': result.get('category', ''),
                    'difficulty': result.get('difficulty', ''),
                    'tags': result.get('tags', []),
                    'answer': result.get('answer', ''),
                    'explanation': result.get('explanation', ''),
                    'status': result.get('status', ''),
                    'created_at': result.get('created_at', ''),
                    'updated_at': result.get('updated_at', ''),
                    'search_score': result.get('search_score', 0.0),
                    'search_type': 'keyword',
                    'query': query,
                    'matched_fields': result.get('matched_fields', [])
                }

                # 如果不需要元数据
                if not include_metadata:
                    enriched_result.pop('answer', None)
                    enriched_result.pop('explanation', None)

                enriched_results.append(enriched_result)

            # 构建返回结果
            result = {
                'query': query,
                'search_type': 'keyword',
                'total_results': len(enriched_results),
                'results': enriched_results,
                'metadata': {
                    'top_k': top_k,
                    'search_fields': search_fields,
                    'match_mode': match_mode,
                    'has_filters': filters is not None,
                    'filters_applied': filters.model_dump() if filters else None
                },
                'timestamp': datetime.now().isoformat()
            }

            self.logger.info(
                f"关键词检索完成: 查询='{query}', "
                f"返回{len(enriched_results)}个结果"
            )

            return result

        except ValueError as e:
            self.logger.error(f"关键词检索参数验证失败: {e}")
            raise
        except Exception as e:
            self.logger.error(f"关键词检索失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # 混合检索
    # -------------------------------------------------------------------------

    def search_hybrid(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[QuestionSearchFilter] = None,
        semantic_weight: float = 0.6,
        keyword_weight: float = 0.4,
        min_similarity: float = 0.0,
        include_metadata: bool = True
    ) -> Dict[str, Any]:
        """
        混合检索 - 结合语义检索和关键词检索

        同时执行语义检索和关键词检索，然后将结果进行融合排序。
        权重可通过semantic_weight和keyword_weight参数调整。

        Args:
            query: 检索查询文本
            top_k: 返回结果数量
            filters: 过滤条件
            semantic_weight: 语义检索权重 (0-1)
            keyword_weight: 关键词检索权重 (0-1)
            min_similarity: 语义检索最低相似度阈值
            include_metadata: 是否包含答案和解析

        Returns:
            Dict: 检索结果

        Raises:
            ValueError: 参数无效
            Exception: 检索过程出错
        """
        try:
            # 验证参数
            if not query or not query.strip():
                raise ValueError("查询文本不能为空")

            if top_k < 1:
                top_k = 1
            if top_k > 100:
                top_k = 100

            # 验证权重
            if semantic_weight < 0 or semantic_weight > 1:
                raise ValueError("语义检索权重必须在0-1之间")

            if keyword_weight < 0 or keyword_weight > 1:
                raise ValueError("关键词检索权重必须在0-1之间")

            # 归一化权重
            total_weight = semantic_weight + keyword_weight
            if total_weight == 0:
                self.logger.warning("权重总和为0，使用默认权重")
                semantic_weight = 0.6
                keyword_weight = 0.4
            else:
                semantic_weight /= total_weight
                keyword_weight /= total_weight

            self.logger.info(
                f"执行混合检索: 查询='{query}', "
                f"语义权重={semantic_weight:.2f}, "
                f"关键词权重={keyword_weight:.2f}"
            )

            # 并发执行两种检索
            semantic_results = None
            keyword_results = None
            semantic_error = None
            keyword_error = None

            # 执行语义检索
            try:
                semantic_results = self.search_by_semantic(
                    query=query,
                    top_k=top_k * 2,  # 扩大检索范围以获得更多候选
                    filters=filters,
                    min_similarity=min_similarity,
                    include_metadata=False
                )
            except Exception as e:
                semantic_error = str(e)
                self.logger.error(f"语义检索失败: {e}")

            # 执行关键词检索
            try:
                keyword_results = self.search_by_keyword(
                    query=query,
                    top_k=top_k * 2,
                    filters=filters,
                    include_metadata=False
                )
            except Exception as e:
                keyword_error = str(e)
                self.logger.error(f"关键词检索失败: {e}")

            # 如果两种检索都失败，抛出异常
            if semantic_error and keyword_error:
                raise Exception(f"语义检索和关键词检索都失败: {semantic_error}, {keyword_error}")

            # 合并结果
            combined_results = self._combine_search_results(
                semantic_results,
                keyword_results,
                semantic_weight,
                keyword_weight,
                include_metadata=include_metadata
            )

            # 构建返回结果
            result = {
                'query': query,
                'search_type': 'hybrid',
                'total_results': len(combined_results),
                'results': combined_results[:top_k],  # 限制返回数量
                'metadata': {
                    'top_k': top_k,
                    'semantic_weight': semantic_weight,
                    'keyword_weight': keyword_weight,
                    'min_similarity': min_similarity,
                    'semantic_available': semantic_results is not None,
                    'keyword_available': keyword_results is not None,
                    'semantic_error': semantic_error,
                    'keyword_error': keyword_error,
                    'has_filters': filters is not None,
                    'filters_applied': filters.model_dump() if filters else None
                },
                'timestamp': datetime.now().isoformat()
            }

            self.logger.info(
                f"混合检索完成: 查询='{query}', "
                f"返回{len(combined_results[:top_k])}个结果"
            )

            return result

        except ValueError as e:
            self.logger.error(f"混合检索参数验证失败: {e}")
            raise
        except Exception as e:
            self.logger.error(f"混合检索失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # 辅助方法
    # -------------------------------------------------------------------------

    def _build_chroma_filter(self, filters: QuestionSearchFilter) -> Dict[str, Any]:
        """
        构建ChromaDB过滤条件

        Args:
            filters: 过滤条件

        Returns:
            Dict: ChromaDB格式的过滤条件
        """
        where_filter = {}

        # 添加元数据过滤条件
        if filters:
            # 题型过滤
            if filters.question_type:
                where_filter['question_type'] = filters.question_type

            # 分类过滤
            if filters.category:
                where_filter['category'] = filters.category

            # 难度过滤
            if filters.difficulty:
                where_filter['difficulty'] = filters.difficulty

            # 标签过滤（使用$contains操作符）
            if filters.tags:
                # ChromaDB支持元数据数组的包含检查
                where_filter['tags'] = {'$contains': filters.tags}

            # 状态过滤
            if filters.status:
                where_filter['status'] = filters.status

        return where_filter

    def _combine_search_results(
        self,
        semantic_results: Optional[Dict[str, Any]],
        keyword_results: Optional[Dict[str, Any]],
        semantic_weight: float,
        keyword_weight: float,
        include_metadata: bool = False
    ) -> List[Dict[str, Any]]:
        """
        合并语义检索和关键词检索的结果

        Args:
            semantic_results: 语义检索结果
            keyword_results: 关键词检索结果
            semantic_weight: 语义检索权重
            keyword_weight: 关键词检索权重
            include_metadata: 是否包含元数据

        Returns:
            List[Dict]: 合并后的结果列表
        """
        # 创建题目ID到结果的映射
        question_scores = {}
        question_data = {}

        # 处理语义检索结果
        if semantic_results and 'results' in semantic_results:
            for result in semantic_results['results']:
                question_id = result['question_id']
                question_scores[question_id] = {
                    'semantic_score': result.get('search_score', 0.0),
                    'keyword_score': 0.0,
                    'combined_score': 0.0
                }
                question_data[question_id] = result

        # 处理关键词检索结果
        if keyword_results and 'results' in keyword_results:
            for result in keyword_results['results']:
                question_id = result['question_id']

                # 如果已存在语义检索结果，添加关键词分数
                if question_id in question_scores:
                    question_scores[question_id]['keyword_score'] = result.get('search_score', 0.0)
                else:
                    # 如果只有关键词检索结果
                    question_scores[question_id] = {
                        'semantic_score': 0.0,
                        'keyword_score': result.get('search_score', 0.0),
                        'combined_score': 0.0
                    }
                    question_data[question_id] = result

        # 计算综合分数
        for question_id, scores in question_scores.items():
            scores['combined_score'] = (
                scores['semantic_score'] * semantic_weight +
                scores['keyword_score'] * keyword_weight
            )

        # 按综合分数排序
        sorted_question_ids = sorted(
            question_scores.keys(),
            key=lambda qid: question_scores[qid]['combined_score'],
            reverse=True
        )

        # 构建最终结果
        combined_results = []
        for question_id in sorted_question_ids:
            if question_id in question_data:
                result = question_data[question_id].copy()

                # 添加检索分数信息
                result['search_score'] = question_scores[question_id]['combined_score']
                result['semantic_score'] = question_scores[question_id]['semantic_score']
                result['keyword_score'] = question_scores[question_id]['keyword_score']
                result['search_type'] = 'hybrid'
                result['query'] = semantic_results.get('query') or keyword_results.get('query') # type: ignore

                # 如果不需要元数据，移除答案和解析
                if not include_metadata:
                    result.pop('answer', None)
                    result.pop('explanation', None)

                combined_results.append(result)

        return combined_results

    # -------------------------------------------------------------------------
    # 统计和管理
    # -------------------------------------------------------------------------

    def get_search_statistics(self) -> Dict[str, Any]:
        """
        获取检索服务的统计信息

        Returns:
            Dict: 统计信息
        """
        try:
            stats = {
                'embedding_cache_stats': self.embedding_service.get_cache_stats(),
                'database_stats': self.db_manager.get_statistics(),
                'timestamp': datetime.now().isoformat()
            }

            return stats

        except Exception as e:
            self.logger.error(f"获取检索统计信息失败: {e}")
            raise
