"""
SQLite 数据访问对象（DAO）

提供 SQLite 数据库的统一访问接口，封装所有数据库操作。
包含连接管理、CRUD操作、事务管理等功能。
"""

import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

from src.core.logger import get_logger
from src.database.models import QuestionModel, TagModel, QuestionSearchFilter


class SQLiteDAO:
    """
    SQLite 数据访问对象
    
    负责管理 SQLite 数据库连接和所有数据操作。
    使用 WAL 模式提高并发性能。
    """
    
    def __init__(self, db_path: str, logger=None):
        """
        初始化 SQLite DAO
        
        Args:
            db_path: 数据库文件路径
            logger: 日志记录器实例
        """
        self.db_path = db_path
        self.logger = logger or get_logger()
        self._ensure_db_directory()
        self._connection: Optional[sqlite3.Connection] = None
        self._use_persistent_connection = (db_path == ":memory:")  # 内存数据库使用持久连接
        
    def _ensure_db_directory(self):
        """确保数据库目录存在"""
        db_file = Path(self.db_path)
        db_file.parent.mkdir(parents=True, exist_ok=True)
        
    def connect(self) -> sqlite3.Connection:
        """
        创建数据库连接
        
        Returns:
            sqlite3.Connection: 数据库连接对象
        """
        try:
            conn = sqlite3.connect(
                self.db_path,
                timeout=30.0,  # 30秒超时
                check_same_thread=False  # 允许多线程使用
            )
            # 启用 WAL 模式提高并发性能
            conn.execute("PRAGMA journal_mode=WAL")
            # 启用外键约束
            conn.execute("PRAGMA foreign_keys=ON")
            # 设置行工厂为字典模式
            conn.row_factory = sqlite3.Row
            
            self.logger.debug(f"成功连接到数据库: {self.db_path}")
            return conn
            
        except sqlite3.Error as e:
            self.logger.error(f"数据库连接失败: {e}")
            raise
    
    @contextmanager
    def get_connection(self):
        """
        获取数据库连接的上下文管理器
        
        Yields:
            sqlite3.Connection: 数据库连接
        """
        if self._use_persistent_connection:
            # 内存数据库使用持久连接
            if self._connection is None:
                self._connection = self.connect()
            yield self._connection
        else:
            # 文件数据库每次创建新连接
            conn = self.connect()
            try:
                yield conn
            finally:
                conn.close()
    
    @contextmanager
    def transaction(self):
        """
        事务上下文管理器
        
        Yields:
            sqlite3.Connection: 数据库连接
        """
        with self.get_connection() as conn:
            try:
                yield conn
                conn.commit()
                self.logger.debug("事务提交成功")
            except Exception as e:
                conn.rollback()
                self.logger.error(f"事务回滚: {e}")
                raise
    
    def initialize_schema(self):
        """
        初始化数据库表结构
        
        读取并执行 schema.sql 文件中的SQL语句。
        """
        try:
            schema_file = Path(__file__).parent / "schema.sql"
            if not schema_file.exists():
                raise FileNotFoundError(f"Schema文件不存在: {schema_file}")
            
            with open(schema_file, "r", encoding="utf-8") as f:
                schema_sql = f.read()
            
            with self.get_connection() as conn:
                conn.executescript(schema_sql)
                self.logger.info("数据库表结构初始化成功")
                
        except Exception as e:
            self.logger.error(f"初始化数据库表结构失败: {e}")
            raise
    
    def check_schema_initialized(self) -> bool:
        """
        检查数据库表结构是否已初始化
        
        Returns:
            bool: True表示已初始化，False表示未初始化
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='questions'"
                )
                return cursor.fetchone() is not None
        except sqlite3.Error:
            return False
    
    # -------------------------------------------------------------------------
    # 题目CRUD操作
    # -------------------------------------------------------------------------
    
    def create_question(self, question: QuestionModel) -> bool:
        """
        创建题目
        
        Args:
            question: 题目数据模型
            
        Returns:
            bool: 创建是否成功
        """
        try:
            sql = """
                INSERT INTO questions (
                    question_id, title, content, question_type, category, difficulty,
                    status, answer, explanation, source, source_url, points,
                    usage_count, correct_rate, created_at, updated_at, created_by
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """
            
            with self.transaction() as conn:
                conn.execute(sql, (
                    question.question_id,
                    question.title,
                    question.content,
                    question.question_type,
                    question.category,
                    question.difficulty,
                    question.status,
                    question.answer,
                    question.explanation,
                    question.source,
                    question.source_url,
                    question.points,
                    question.usage_count,
                    question.correct_rate,
                    question.created_at,
                    question.updated_at,
                    question.created_by
                ))
                
                # 处理标签关联
                if question.tags:
                    self._associate_tags(conn, question.question_id, question.tags)
                
            self.logger.info(f"成功创建题目: {question.question_id}")
            return True
            
        except sqlite3.IntegrityError as e:
            self.logger.error(f"题目已存在或违反唯一性约束: {e}")
            raise ValueError(f"题目已存在: {question.question_id}")
        except Exception as e:
            self.logger.error(f"创建题目失败: {e}")
            raise
    
    def get_question(self, question_id: str) -> Optional[Dict[str, Any]]:
        """
        获取题目详情
        
        Args:
            question_id: 题目ID
            
        Returns:
            Optional[Dict]: 题目数据字典，如果不存在则返回None
        """
        try:
            sql = "SELECT * FROM v_questions_with_tags WHERE question_id = ?"
            
            with self.get_connection() as conn:
                cursor = conn.execute(sql, (question_id,))
                row = cursor.fetchone()
                
                if row:
                    data = dict(row)
                    # 解析标签字符串为列表
                    if data.get('tags'):
                        data['tags'] = data['tags'].split(',')
                    else:
                        data['tags'] = []
                    return data
                return None
                
        except Exception as e:
            self.logger.error(f"获取题目失败: {e}")
            raise
    
    def update_question(self, question_id: str, updates: Dict[str, Any]) -> bool:
        """
        更新题目
        
        Args:
            question_id: 题目ID
            updates: 要更新的字段字典
            
        Returns:
            bool: 更新是否成功
        """
        try:
            if not updates:
                return False
            
            # 构建UPDATE语句
            set_clause = ", ".join([f"{k} = ?" for k in updates.keys() if k != 'tags'])
            values = [v for k, v in updates.items() if k != 'tags']
            values.append(question_id)
            
            sql = f"UPDATE questions SET {set_clause}, updated_at = datetime('now') WHERE question_id = ?"
            
            with self.transaction() as conn:
                cursor = conn.execute(sql, values)
                
                if cursor.rowcount == 0:
                    raise ValueError(f"题目不存在: {question_id}")
                
                # 处理标签更新
                if 'tags' in updates:
                    self._update_tags(conn, question_id, updates['tags'])
                
            self.logger.info(f"成功更新题目: {question_id}")
            return True
            
        except Exception as e:
            self.logger.error(f"更新题目失败: {e}")
            raise
    
    def delete_question(self, question_id: str, soft_delete: bool = False) -> bool:
        """
        删除题目
        
        Args:
            question_id: 题目ID
            soft_delete: 是否软删除（改为已归档状态）
            
        Returns:
            bool: 删除是否成功
        """
        try:
            if soft_delete:
                # 软删除：更新状态为已归档
                sql = "UPDATE questions SET status = '已归档', updated_at = datetime('now') WHERE question_id = ?"
            else:
                # 硬删除：物理删除记录（外键级联会自动删除关联）
                sql = "DELETE FROM questions WHERE question_id = ?"
            
            with self.transaction() as conn:
                cursor = conn.execute(sql, (question_id,))
                
                if cursor.rowcount == 0:
                    raise ValueError(f"题目不存在: {question_id}")
            
            delete_type = "软删除" if soft_delete else "硬删除"
            self.logger.info(f"成功{delete_type}题目: {question_id}")
            return True
            
        except Exception as e:
            self.logger.error(f"删除题目失败: {e}")
            raise
    
    def list_questions(
        self,
        filters: Optional[QuestionSearchFilter] = None,
        page: int = 1,
        page_size: int = 20,
        sort_by: str = "created_at",
        sort_order: str = "DESC"
    ) -> Tuple[List[Dict[str, Any]], int]:
        """
        查询题目列表（分页）
        
        Args:
            filters: 过滤条件
            page: 页码（从1开始）
            page_size: 每页数量
            sort_by: 排序字段
            sort_order: 排序方向（ASC/DESC）
            
        Returns:
            Tuple[List[Dict], int]: (题目列表, 总数量)
        """
        try:
            # 构建WHERE子句
            where_conditions = []
            params = []
            
            if filters:
                if filters.category:
                    where_conditions.append("category = ?")
                    params.append(filters.category)
                if filters.difficulty:
                    where_conditions.append("difficulty = ?")
                    params.append(filters.difficulty)
                if filters.question_type:
                    where_conditions.append("question_type = ?")
                    params.append(filters.question_type)
                if filters.status:
                    where_conditions.append("status = ?")
                    params.append(filters.status)
            
            where_clause = f"WHERE {' AND '.join(where_conditions)}" if where_conditions else ""
            
            # 查询总数
            count_sql = f"SELECT COUNT(*) FROM questions {where_clause}"
            
            # 查询数据
            offset = (page - 1) * page_size
            data_sql = f"""
                SELECT * FROM v_questions_with_tags 
                {where_clause}
                ORDER BY {sort_by} {sort_order}
                LIMIT ? OFFSET ?
            """
            
            with self.get_connection() as conn:
                # 获取总数
                cursor = conn.execute(count_sql, params)
                total = cursor.fetchone()[0]
                
                # 获取数据
                cursor = conn.execute(data_sql, params + [page_size, offset])
                rows = cursor.fetchall()
                
                items = []
                for row in rows:
                    data = dict(row)
                    if data.get('tags'):
                        data['tags'] = data['tags'].split(',')
                    else:
                        data['tags'] = []
                    items.append(data)
                
                return items, total
                
        except Exception as e:
            self.logger.error(f"查询题目列表失败: {e}")
            raise
    
    # -------------------------------------------------------------------------
    # 标签管理操作
    # -------------------------------------------------------------------------
    
    def _get_or_create_tag(self, conn: sqlite3.Connection, tag_name: str) -> int:
        """
        获取或创建标签
        
        Args:
            conn: 数据库连接
            tag_name: 标签名称
            
        Returns:
            int: 标签ID
        """
        # 尝试获取已存在的标签
        cursor = conn.execute("SELECT tag_id FROM tags WHERE tag_name = ?", (tag_name,))
        row = cursor.fetchone()
        
        if row:
            return row[0]
        
        # 创建新标签
        cursor = conn.execute("INSERT INTO tags (tag_name) VALUES (?)", (tag_name,))
        tag_id = cursor.lastrowid
        if tag_id is None:
            raise RuntimeError("创建标签失败：未获取到tag_id")
        return tag_id
    
    def _associate_tags(self, conn: sqlite3.Connection, question_id: str, tags: List[str]):
        """
        关联题目和标签
        
        Args:
            conn: 数据库连接
            question_id: 题目ID
            tags: 标签名称列表
        """
        for tag_name in tags:
            tag_id = self._get_or_create_tag(conn, tag_name.strip())
            # 插入关联（如果已存在则忽略）
            conn.execute(
                "INSERT OR IGNORE INTO question_tags (question_id, tag_id) VALUES (?, ?)",
                (question_id, tag_id)
            )
    
    def _update_tags(self, conn: sqlite3.Connection, question_id: str, new_tags: List[str]):
        """
        更新题目的标签
        
        Args:
            conn: 数据库连接
            question_id: 题目ID
            new_tags: 新的标签列表
        """
        # 删除现有关联
        conn.execute("DELETE FROM question_tags WHERE question_id = ?", (question_id,))
        
        # 创建新关联
        if new_tags:
            self._associate_tags(conn, question_id, new_tags)
    
    def get_all_tags(self) -> List[Dict[str, Any]]:
        """
        获取所有标签
        
        Returns:
            List[Dict]: 标签列表
        """
        try:
            sql = "SELECT * FROM tags ORDER BY usage_count DESC"
            
            with self.get_connection() as conn:
                cursor = conn.execute(sql)
                return [dict(row) for row in cursor.fetchall()]
                
        except Exception as e:
            self.logger.error(f"获取标签列表失败: {e}")
            raise
    
    def get_tag_statistics(self, top_n: int = 20) -> List[Dict[str, Any]]:
        """
        获取标签统计信息
        
        Args:
            top_n: 返回前N个标签
            
        Returns:
            List[Dict]: 标签统计列表
        """
        try:
            sql = """
                SELECT t.tag_id, t.tag_name, t.usage_count,
                       GROUP_CONCAT(DISTINCT q.category) as categories
                FROM tags t
                LEFT JOIN question_tags qt ON t.tag_id = qt.tag_id
                LEFT JOIN questions q ON qt.question_id = q.question_id
                GROUP BY t.tag_id
                ORDER BY t.usage_count DESC
                LIMIT ?
            """
            
            with self.get_connection() as conn:
                cursor = conn.execute(sql, (top_n,))
                rows = cursor.fetchall()
                
                result = []
                for row in rows:
                    data = dict(row)
                    if data.get('categories'):
                        data['categories'] = data['categories'].split(',')
                    else:
                        data['categories'] = []
                    result.append(data)
                
                return result
                
        except Exception as e:
            self.logger.error(f"获取标签统计失败: {e}")
            raise
    
    # -------------------------------------------------------------------------
    # 关键词检索操作
    # -------------------------------------------------------------------------

    def search_questions_by_keyword(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[QuestionSearchFilter] = None,
        search_fields: Optional[List[str]] = None,
        match_mode: str = 'OR'
    ) -> List[Dict[str, Any]]:
        """
        关键词检索 - 基于FTS5全文检索

        Args:
            query: 检索查询文本
            top_k: 返回结果数量
            filters: 过滤条件
            search_fields: 搜索字段列表，默认为['content', 'title', 'tags']
            match_mode: 匹配模式，'OR'或'AND'

        Returns:
            List[Dict]: 检索结果列表

        Raises:
            ValueError: 查询参数无效
            Exception: 检索过程出错
        """
        try:
            # 验证参数
            if not query or not query.strip():
                raise ValueError("查询文本不能为空")

            if top_k < 1:
                top_k = 1
            if top_k > 100:
                top_k = 100

            # 设置默认搜索字段
            if search_fields is None:
                search_fields = ['content', 'title', 'tags']

            # 准备MATCH查询表达式
            keywords = query.strip().split()
            if not keywords:
                return []

            # 根据匹配模式构建查询
            if match_mode == 'AND':
                match_query = ' AND '.join([f'"{kw}"' for kw in keywords])
            else:  # OR模式
                match_query = ' OR '.join([f'"{kw}"' for kw in keywords])

            # 构建WHERE子句（过滤条件）
            where_conditions = []
            params = []

            if filters:
                if filters.category:
                    where_conditions.append("q.category = ?")
                    params.append(filters.category)
                if filters.difficulty:
                    where_conditions.append("q.difficulty = ?")
                    params.append(filters.difficulty)
                if filters.question_type:
                    where_conditions.append("q.question_type = ?")
                    params.append(filters.question_type)
                if filters.status:
                    where_conditions.append("q.status = ?")
                    params.append(filters.status)

            where_clause = ""
            if where_conditions:
                where_clause = " AND " + " AND ".join(where_conditions)

            # 构建最终SQL查询
            sql = f"""
                SELECT
                    q.*,
                    fts.rank as search_score,
                    bm25(questions_fts) as bm25_score
                FROM questions_fts fts
                JOIN questions q ON fts.question_id = q.question_id
                WHERE questions_fts MATCH ?{where_clause}
                ORDER BY fts.rank
                LIMIT ?
            """

            params = [match_query] + params + [top_k]

            with self.get_connection() as conn:
                cursor = conn.execute(sql, params)
                rows = cursor.fetchall()

                # 转换为字典格式
                results = []
                for row in rows:
                    data = dict(row)
                    # 处理标签（从字符串转换为列表）
                    if data.get('tags'):
                        data['tags'] = data['tags'].split(',') if data['tags'] else []
                    else:
                        data['tags'] = []

                    # 添加搜索相关字段
                    data['search_score'] = data.get('search_score', 0.0) or 0.0
                    data['bm25_score'] = data.get('bm25_score', 0.0) or 0.0

                    results.append(data)

            self.logger.info(
                f"关键词检索完成: 查询='{query}', "
                f"匹配模式={match_mode}, 返回{len(results)}个结果"
            )

            return results

        except sqlite3.Error as e:
            self.logger.error(f"FTS检索失败: {e}")
            raise
        except Exception as e:
            self.logger.error(f"关键词检索失败: {e}")
            raise

    def search_questions_by_keyword_enhanced(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[QuestionSearchFilter] = None,
        search_fields: Optional[List[str]] = None,
        match_mode: str = 'OR'
    ) -> List[Dict[str, Any]]:
        """
        关键词检索增强版 - 基于FTS5并返回匹配字段信息

        Args:
            query: 检索查询文本
            top_k: 返回结果数量
            filters: 过滤条件
            search_fields: 搜索字段列表，默认为['content', 'title', 'tags']
            match_mode: 匹配模式，'OR'或'AND'

        Returns:
            List[Dict]: 检索结果列表，包含matched_fields字段

        Raises:
            ValueError: 查询参数无效
            Exception: 检索过程出错
        """
        try:
            # 验证参数
            if not query or not query.strip():
                raise ValueError("查询文本不能为空")

            # 设置默认搜索字段
            if search_fields is None:
                search_fields = ['content', 'title', 'tags']

            # 准备MATCH查询表达式
            keywords = query.strip().split()
            if not keywords:
                return []

            # 构建字段特定的查询
            field_queries = []
            for field in search_fields:
                field_keywords = [f'{field}:"{kw}"' for kw in keywords]
                if match_mode == 'AND':
                    field_queries.append('(' + ' AND '.join(field_keywords) + ')')
                else:  # OR
                    field_queries.append('(' + ' OR '.join(field_keywords) + ')')

            if match_mode == 'AND':
                match_query = ' AND '.join(field_queries)
            else:
                match_query = ' OR '.join(field_queries)

            # 构建WHERE子句（过滤条件）
            where_conditions = []
            params = []

            if filters:
                if filters.category:
                    where_conditions.append("q.category = ?")
                    params.append(filters.category)
                if filters.difficulty:
                    where_conditions.append("q.difficulty = ?")
                    params.append(filters.difficulty)
                if filters.question_type:
                    where_conditions.append("q.question_type = ?")
                    params.append(filters.question_type)
                if filters.status:
                    where_conditions.append("q.status = ?")
                    params.append(filters.status)

            where_clause = ""
            if where_conditions:
                where_clause = " AND " + " AND ".join(where_conditions)

            # 构建最终SQL查询（获取匹配字段信息）
            sql = f"""
                SELECT
                    q.*,
                    fts.rank as search_score,
                    bm25(questions_fts) as bm25_score
                FROM questions_fts fts
                JOIN questions q ON fts.question_id = q.question_id
                WHERE questions_fts MATCH ?{where_clause}
                ORDER BY fts.rank
                LIMIT ?
            """

            params = [match_query] + params + [top_k]

            with self.get_connection() as conn:
                cursor = conn.execute(sql, params)
                rows = cursor.fetchall()

                # 转换为字典格式
                results = []
                for row in rows:
                    data = dict(row)
                    # 处理标签
                    if data.get('tags'):
                        data['tags'] = data['tags'].split(',') if data['tags'] else []
                    else:
                        data['tags'] = []

                    # 添加搜索相关字段
                    data['search_score'] = data.get('search_score', 0.0) or 0.0
                    data['bm25_score'] = data.get('bm25_score', 0.0) or 0.0
                    data['matched_fields'] = search_fields  # 记录搜索字段

                    results.append(data)

            self.logger.info(
                f"关键词检索增强版完成: 查询='{query}', "
                f"匹配模式={match_mode}, 搜索字段={search_fields}, "
                f"返回{len(results)}个结果"
            )

            return results

        except sqlite3.Error as e:
            self.logger.error(f"FTS增强检索失败: {e}")
            raise
        except Exception as e:
            self.logger.error(f"关键词检索增强版失败: {e}")
            raise

    # -------------------------------------------------------------------------
    # 统计查询操作
    # -------------------------------------------------------------------------
    
    def get_statistics_overview(self) -> Dict[str, Any]:
        """
        获取题库整体统计数据
        
        Returns:
            Dict: 统计数据
        """
        try:
            with self.get_connection() as conn:
                # 总体统计
                cursor = conn.execute("""
                    SELECT 
                        COUNT(*) as total_questions,
                        SUM(CASE WHEN status='已发布' THEN 1 ELSE 0 END) as published_count,
                        AVG(CASE WHEN usage_count > 0 THEN usage_count END) as avg_usage_count,
                        AVG(CASE WHEN correct_rate IS NOT NULL THEN correct_rate END) as avg_correct_rate
                    FROM questions
                """)
                overview = dict(cursor.fetchone())
                
                # 按分类统计
                cursor = conn.execute("""
                    SELECT category, COUNT(*) as count
                    FROM questions
                    GROUP BY category
                """)
                overview['by_category'] = [dict(row) for row in cursor.fetchall()]
                
                # 按难度统计
                cursor = conn.execute("""
                    SELECT difficulty, COUNT(*) as count
                    FROM questions
                    GROUP BY difficulty
                """)
                overview['by_difficulty'] = [dict(row) for row in cursor.fetchall()]
                
                # 按题型统计
                cursor = conn.execute("""
                    SELECT question_type, COUNT(*) as count
                    FROM questions
                    GROUP BY question_type
                """)
                overview['by_question_type'] = [dict(row) for row in cursor.fetchall()]
                
                # 按状态统计
                cursor = conn.execute("""
                    SELECT status, COUNT(*) as count
                    FROM questions
                    GROUP BY status
                """)
                overview['by_status'] = [dict(row) for row in cursor.fetchall()]
                
                return overview
                
        except Exception as e:
            self.logger.error(f"获取统计概览失败: {e}")
            raise
    
    def close(self):
        """关闭数据库连接"""
        if self._connection:
            self._connection.close()
            self._connection = None
            self.logger.debug("数据库连接已关闭")
