"""
批量导入服务单元测试

测试批量导入功能的各个组件：
- 导入会话管理
- 数据源对接
- 重复检测
- 断点续传
- 数据验证和清洗
"""

import pytest
import tempfile
import os
import json
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime

from src.services.import_service import (
    ImportService,
    ImportSource,
    ImportStatus
)
from src.database.models import QuestionCreateDTO
from src.core.config import Config


class TestImportService:
    """Unit tests for the batch-import service (``ImportService``).

    Every collaborating service is a ``Mock``. The checkpoint file is
    redirected into pytest's per-test ``tmp_path`` so tests never write into
    the working directory and never see checkpoint state left by another test.
    """

    @pytest.fixture
    def config(self, tmp_path):
        """Build a mocked ``Config`` whose ``get`` serves import settings.

        ``config.get(key, default)`` looks ``key`` up in a fixed table and
        falls back to ``default`` for unknown keys, mirroring ``dict.get``.

        Args:
            tmp_path: pytest built-in fixture; a fresh directory per test.
                The checkpoint file is placed here instead of ``./data`` so
                checkpoint tests are isolated and leave nothing behind.
        """
        # Built once per fixture call rather than on every ``get`` invocation.
        settings = {
            "import.batch_size": 50,
            "import.max_retries": 3,
            "import.retry_delay": 2,
            "import.duplicate_detection.enabled": True,
            "import.duplicate_detection.similarity_threshold": 0.95,
            "import.checkpoint.enabled": True,
            "import.checkpoint.file": str(tmp_path / "test_checkpoint.json"),
            "import.external_api.timeout": 60,
        }

        config = Mock(spec=Config)
        config.get = Mock(side_effect=lambda key, default=None: settings.get(key, default))
        config.sqlite_db_path = ":memory:"
        config.chromadb_persist_dir = ":memory:"
        return config

    @pytest.fixture
    def mock_services(self):
        """Create mocked collaborators, keyed by ImportService constructor argument name."""
        return {
            "db_manager": Mock(),
            "embedding_service": Mock(),
            "management_service": Mock(),
            "search_service": Mock(),
        }

    @pytest.fixture
    def import_service(self, config, mock_services):
        """Create the ImportService under test, wired to the mocks and mocked config."""
        return ImportService(
            db_manager=mock_services["db_manager"],
            embedding_service=mock_services["embedding_service"],
            management_service=mock_services["management_service"],
            search_service=mock_services["search_service"],
            config=config
        )

    # -------------------------------------------------------------------------
    # Import session management
    # -------------------------------------------------------------------------

    def test_create_import_session_external_api(self, import_service):
        """Creating an external-API session stores its type, config and PENDING status."""
        # Arrange
        source_type = ImportSource.EXTERNAL_API
        source_config = {
            "endpoint": "https://api.example.com/questions",
            "api_key": "test_key"
        }

        # Act
        session_id = import_service.create_import_session(
            source_type=source_type,
            source_config=source_config
        )

        # Assert
        assert session_id is not None
        session = import_service.get_import_session(session_id)
        assert session["source_type"] == source_type
        assert session["status"] == ImportStatus.PENDING
        assert session["source_config"] == source_config

    def test_create_import_session_json_file(self, import_service):
        """Creating a JSON-file session records the JSON_FILE source type."""
        # Arrange
        source_type = ImportSource.JSON_FILE
        source_config = {
            "file_path": "/path/to/test.json"
        }

        # Act
        session_id = import_service.create_import_session(
            source_type=source_type,
            source_config=source_config
        )

        # Assert
        assert session_id is not None
        session = import_service.get_import_session(session_id)
        assert session["source_type"] == source_type

    def test_create_import_session_manual(self, import_service):
        """Creating a manual session (questions passed inline) records the MANUAL source type."""
        # Arrange
        source_type = ImportSource.MANUAL
        questions = [
            {
                "content": "测试题目1",
                "title": "测试1",
                "category": "数学",
                "difficulty": "简单"
            }
        ]
        source_config = {
            "questions": questions
        }

        # Act
        session_id = import_service.create_import_session(
            source_type=source_type,
            source_config=source_config
        )

        # Assert
        assert session_id is not None
        session = import_service.get_import_session(session_id)
        assert session["source_type"] == source_type

    def test_create_import_session_invalid_type(self, import_service):
        """An unknown source type is rejected with ValueError."""
        # Act & Assert
        with pytest.raises(ValueError, match="不支持的数据源类型"):
            import_service.create_import_session(
                source_type="invalid_type",
                source_config={}
            )

    def test_list_import_sessions(self, import_service):
        """Listing returns all sessions; filtering by PENDING returns only pending ones."""
        # Arrange: two fresh sessions, both expected to start as PENDING.
        for _ in range(2):
            import_service.create_import_session(
                source_type=ImportSource.MANUAL,
                source_config={"questions": []}
            )

        # Act
        all_sessions = import_service.list_import_sessions()
        pending_sessions = import_service.list_import_sessions(status=ImportStatus.PENDING)

        # Assert
        assert len(all_sessions) >= 2
        # Without a length check the ``all(...)`` below would pass vacuously
        # on an empty result; both sessions created above are still PENDING.
        assert len(pending_sessions) >= 2
        assert all(s["status"] == ImportStatus.PENDING for s in pending_sessions)

    def test_get_import_session_not_exists(self, import_service):
        """Looking up an unknown session id returns None."""
        session = import_service.get_import_session("non_existent_id")
        assert session is None

    # -------------------------------------------------------------------------
    # Data validation and cleaning
    # -------------------------------------------------------------------------

    def test_validate_and_clean_data_valid(self, import_service):
        """A complete, valid record is converted into a QuestionCreateDTO."""
        # Arrange
        raw_data = {
            "content": "这是一个测试题目",
            "title": "测试题",
            "question_type": "single_choice",
            "category": "数学",
            "difficulty": "简单",
            "tags": ["数学", "代数"],
            "answer": "A",
            "explanation": "答案解析"
        }
        options = {
            "import_answers": True,
            "import_explanations": True
        }

        # Act
        dto = import_service._validate_and_clean_data(raw_data, options)

        # Assert
        assert isinstance(dto, QuestionCreateDTO)
        assert dto.content == "这是一个测试题目"
        assert dto.title == "测试题"
        assert dto.category == "数学"
        assert dto.tags == ["数学", "代数"]

    def test_validate_and_clean_data_empty_content(self, import_service):
        """An empty ``content`` string is rejected."""
        # Arrange
        raw_data = {"content": ""}

        # Act & Assert
        with pytest.raises(ValueError, match="题目内容不能为空"):
            import_service._validate_and_clean_data(raw_data, {})

    def test_validate_and_clean_data_missing_content(self, import_service):
        """A record without a ``content`` key is rejected like an empty one."""
        # Arrange
        raw_data = {"title": "测试题"}

        # Act & Assert
        with pytest.raises(ValueError, match="题目内容不能为空"):
            import_service._validate_and_clean_data(raw_data, {})

    def test_validate_and_clean_data_string_tags(self, import_service):
        """A comma-separated tag string is split into a list of tags."""
        # Arrange
        raw_data = {
            "content": "测试题目",
            "tags": "数学,代数,几何"
        }

        # Act
        dto = import_service._validate_and_clean_data(raw_data, {})

        # Assert
        assert dto.tags == ["数学", "代数", "几何"]

    def test_validate_and_clean_data_options(self, import_service):
        """``import_answers`` / ``import_explanations`` set to False drop those fields."""
        # Arrange
        raw_data = {
            "content": "测试题目",
            "title": "测试题",
            "question_type": "单选",
            "category": "数学",
            "difficulty": "简单",
            "answer": "A",
            "explanation": "解析"
        }

        # Act
        dto_no_answer = import_service._validate_and_clean_data(raw_data, {"import_answers": False})
        dto_no_explanation = import_service._validate_and_clean_data(raw_data, {"import_explanations": False})

        # Assert: the field is either absent from the DTO or left empty/None.
        assert not hasattr(dto_no_answer, "answer") or dto_no_answer.answer == "" or dto_no_answer.answer is None
        assert not hasattr(dto_no_explanation, "explanation") or dto_no_explanation.explanation == "" or dto_no_explanation.explanation is None

    # -------------------------------------------------------------------------
    # Duplicate detection
    # -------------------------------------------------------------------------

    def test_check_duplicates_no_similar(self, import_service, mock_services):
        """No semantic-search hits means the question is not a duplicate."""
        # Arrange: the search mock returns (results, total_count).
        mock_services["search_service"].search_by_semantic.return_value = ([], 0)

        question_data = QuestionCreateDTO(
            content="这是一个唯一的测试题目",
            title="测试题",
            question_type="单选",
            category="数学",
            difficulty="简单"
        )

        # Act
        is_duplicate, info = import_service._check_duplicates(question_data)

        # Assert
        assert not is_duplicate
        assert info is None

    def test_check_duplicates_found_similar(self, import_service, mock_services):
        """A hit above the 0.95 similarity threshold is reported as a duplicate."""
        # Arrange
        mock_services["search_service"].search_by_semantic.return_value = ([
            {
                "question_id": "123",
                "title": "相似题目",
                "similarity": 0.98
            }
        ], 1)

        question_data = QuestionCreateDTO(
            content="这是一个测试题目",
            title="测试题",
            question_type="单选",
            category="数学",
            difficulty="简单"
        )

        # Act
        is_duplicate, info = import_service._check_duplicates(question_data)

        # Assert
        assert is_duplicate
        assert info is not None
        assert info["matched_question_id"] == "123"
        assert info["similarity"] == 0.98

    def test_check_duplicates_low_similarity(self, import_service, mock_services):
        """A hit below the similarity threshold is not treated as a duplicate."""
        # Arrange
        mock_services["search_service"].search_by_semantic.return_value = ([
            {
                "question_id": "123",
                "title": "不同题目",
                "similarity": 0.5
            }
        ], 1)

        question_data = QuestionCreateDTO(
            content="这是一个测试题目",
            title="测试题",
            question_type="单选",
            category="数学",
            difficulty="简单"
        )

        # Act
        is_duplicate, info = import_service._check_duplicates(question_data)

        # Assert
        assert not is_duplicate

    # -------------------------------------------------------------------------
    # Batch processing
    # -------------------------------------------------------------------------

    def test_process_import_batch_validation_errors(self, import_service):
        """Invalid records in a batch are counted as failures and reported in ``errors``."""
        # Arrange
        raw_questions = [
            {
                "content": "正常题目",
                "title": "测试题1",
                "question_type": "单选",
                "category": "数学",
                "difficulty": "简单"
            },
            {"content": ""},  # invalid record: empty content
            {
                "content": "另一个正常题目",
                "title": "测试题2",
                "question_type": "单选",
                "category": "语文",
                "difficulty": "中等"
            }
        ]
        session = {
            "options": {
                "skip_duplicates": False,
                "validate_only": False
            }
        }

        # Act
        result = import_service._process_import_batch(session, raw_questions)

        # Assert
        assert result["processed"] == 3
        assert result["failed"] >= 1  # at least the empty-content record
        assert len(result["errors"]) >= 1

    def test_process_import_batch_validate_only(self, import_service, mock_services):
        """Validate-only mode processes every record without creating any question."""
        # Arrange
        raw_questions = [
            {
                "content": "正常题目1",
                "title": "测试题1",
                "question_type": "单选",
                "category": "数学",
                "difficulty": "简单"
            },
            {
                "content": "正常题目2",
                "title": "测试题2",
                "question_type": "单选",
                "category": "语文",
                "difficulty": "中等"
            }
        ]
        session = {
            "options": {
                "skip_duplicates": False,
                "validate_only": True
            }
        }

        # Act
        result = import_service._process_import_batch(session, raw_questions)

        # Assert
        assert result["processed"] == 2
        assert result["successful"] >= 0
        assert result["failed"] >= 0
        # Validate-only must not persist anything through the management service.
        mock_services["management_service"].create_question.assert_not_called()

    # -------------------------------------------------------------------------
    # Checkpoint / resume support
    # -------------------------------------------------------------------------

    def test_save_checkpoint(self, import_service):
        """Saving a checkpoint writes it to the JSON checkpoint file under the session id."""
        # Arrange
        session_id = "test_session"
        checkpoint_data = {"processed": 100, "page": 5}

        # Act
        import_service._save_checkpoint(session_id, checkpoint_data)

        # Assert
        assert os.path.exists(import_service.checkpoint_file)
        with open(import_service.checkpoint_file, 'r', encoding="utf-8") as f:
            checkpoints = json.load(f)
        assert session_id in checkpoints
        assert checkpoints[session_id]["processed"] == 100

    def test_clear_checkpoint(self, import_service):
        """Clearing a checkpoint returns True and removes the session's entry from the file."""
        # Arrange
        session_id = "test_session"
        checkpoint_data = {"processed": 100}
        import_service._save_checkpoint(session_id, checkpoint_data)

        # Act
        result = import_service.clear_checkpoint(session_id)

        # Assert
        assert result is True
        with open(import_service.checkpoint_file, 'r', encoding="utf-8") as f:
            checkpoints = json.load(f)
        assert session_id not in checkpoints

    # -------------------------------------------------------------------------
    # Manual import
    # -------------------------------------------------------------------------

    def test_import_from_manual(self, import_service, mock_services):
        """Running a manual session imports every inline question successfully."""
        # Arrange
        questions = [
            {
                "content": "测试题目1",
                "title": "测试1",
                "question_type": "单选",
                "category": "数学",
                "difficulty": "简单"
            },
            {
                "content": "测试题目2",
                "title": "测试2",
                "question_type": "单选",
                "category": "语文",
                "difficulty": "中等"
            }
        ]

        mock_services["management_service"].create_question.return_value = "question_id_123"

        session_id = import_service.create_import_session(
            source_type=ImportSource.MANUAL,
            source_config={"questions": questions}
        )

        # Act
        result = import_service.start_import(session_id)

        # Assert
        assert result["total"] == 2
        assert result["successful"] == 2
        assert result["failed"] == 0

    # -------------------------------------------------------------------------
    # Import reports
    # -------------------------------------------------------------------------

    def test_generate_import_report(self, import_service, mock_services):
        """The report for a finished session carries its id, source, status and stats."""
        # Arrange
        questions = [{
            "content": "测试题目",
            "title": "测试题",
            "question_type": "单选",
            "category": "数学",
            "difficulty": "简单"
        }]
        mock_services["management_service"].create_question.return_value = "question_id_123"

        session_id = import_service.create_import_session(
            source_type=ImportSource.MANUAL,
            source_config={"questions": questions}
        )

        # Act
        import_service.start_import(session_id)
        report = import_service.generate_import_report(session_id)

        # Assert
        assert report["session_id"] == session_id
        assert report["source_type"] == ImportSource.MANUAL
        assert report["status"] == ImportStatus.COMPLETED
        assert "statistics" in report
        assert "duration_seconds" in report

    def test_export_import_statistics(self, import_service, mock_services):
        """Aggregate statistics grow by the sessions and successes added in this test."""
        # Arrange
        mock_services["management_service"].create_question.return_value = "question_id_123"

        # Record the session count before this test adds any.
        initial_stats = import_service.export_import_statistics()
        initial_sessions = initial_stats.get("total_statistics", {}).get("total_sessions", 0)

        # Create and run three single-question manual sessions.
        for i in range(3):
            questions = [{
                "content": f"测试题目{i}",
                "title": f"测试{i}",
                "question_type": "单选",
                "category": "数学",
                "difficulty": "简单"
            }]
            session_id = import_service.create_import_session(
                source_type=ImportSource.MANUAL,
                source_config={"questions": questions}
            )
            import_service.start_import(session_id)

        # Act
        stats = import_service.export_import_statistics()

        # Assert
        assert "total_statistics" in stats
        assert "by_source" in stats
        # The three sessions created above must be reflected in the totals.
        assert stats["total_statistics"]["total_sessions"] >= initial_sessions + 3
        assert stats["total_statistics"]["total_successful"] >= 3

    # -------------------------------------------------------------------------
    # Session state transitions
    # -------------------------------------------------------------------------

    def test_pause_import(self, import_service):
        """Pausing a RUNNING session marks it PAUSED and records a checkpoint."""
        # Arrange
        session_id = import_service.create_import_session(
            source_type=ImportSource.MANUAL,
            source_config={"questions": []}
        )

        # Force the session into the RUNNING state.
        session = import_service._load_session(session_id)
        session["status"] = ImportStatus.RUNNING
        import_service._save_session(session)

        # Act
        result = import_service.pause_import(session_id)

        # Assert
        assert result is True
        updated_session = import_service._load_session(session_id)
        assert updated_session["status"] == ImportStatus.PAUSED
        assert "checkpoint" in updated_session

    def test_resume_import(self, import_service, mock_services):
        """Resuming a PAUSED session runs the import again and returns a result summary."""
        # Arrange
        questions = [{
            "content": "测试题目",
            "title": "测试题",
            "question_type": "单选",
            "category": "数学",
            "difficulty": "简单"
        }]
        mock_services["management_service"].create_question.return_value = "question_id_123"

        session_id = import_service.create_import_session(
            source_type=ImportSource.MANUAL,
            source_config={"questions": questions}
        )

        # Run once, then force the session into the PAUSED state.
        import_service.start_import(session_id)
        session = import_service._load_session(session_id)
        session["status"] = ImportStatus.PAUSED
        import_service._save_session(session)

        # Act
        result = import_service.resume_import(session_id)

        # Assert
        assert result["total"] >= 0

    def test_cancel_import(self, import_service):
        """Cancelling a session returns True and marks it CANCELLED."""
        # Arrange
        session_id = import_service.create_import_session(
            source_type=ImportSource.MANUAL,
            source_config={"questions": []}
        )

        # Act
        result = import_service.cancel_import(session_id)

        # Assert
        assert result is True
        session = import_service._load_session(session_id)
        assert session["status"] == ImportStatus.CANCELLED