"""
阶段3验证脚本 - Embedding API集成

验证内容：
1. LRU缓存功能
2. Embedding缓存功能
3. 熔断器功能
4. Embedding服务基本功能
"""

import sys
import time
from pathlib import Path

# 添加项目根目录到路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from src.utils.cache import LRUCache, EmbeddingCache, QueryCache
from src.services.embedding_service import CircuitBreaker, EmbeddingService, EmbeddingProvider
from src.core.config import Config


def test_lru_cache():
    """测试LRU缓存"""
    print("\n" + "=" * 60)
    print("测试1: LRU缓存功能")
    print("=" * 60)
    
    cache = LRUCache(max_size=3, ttl=2)
    
    # 测试设置和获取
    cache.set("key1", "value1")
    cache.set("key2", "value2")
    cache.set("key3", "value3")
    
    assert cache.get("key1") == "value1", "缓存获取失败"
    print("✅ 缓存设置和获取正常")
    
    # 测试LRU淘汰
    # 访问key2和key3，使key1成为最旧的
    cache.get("key2")
    cache.get("key3")
    cache.set("key4", "value4")  # 应该淘汰key1
    assert cache.get("key1") is None, "LRU淘汰失败"
    assert cache.get("key2") == "value2", "key2不应被淘汰"
    assert cache.get("key3") == "value3", "key3不应被淘汰"
    assert cache.get("key4") == "value4", "新键未添加"
    print("✅ LRU淘汰机制正常")
    
    # 测试TTL
    cache.set("temp", "temporary", ttl=1)
    assert cache.get("temp") == "temporary", "TTL设置失败"
    time.sleep(1.1)
    assert cache.get("temp") is None, "TTL过期失败"
    print("✅ TTL机制正常")
    
    # 测试统计
    stats = cache.get_stats()
    print(f"✅ 缓存统计: 命中率={stats.hit_rate:.2%}, 总请求={stats.total_requests}")
    
    return True


def test_embedding_cache():
    """测试Embedding缓存"""
    print("\n" + "=" * 60)
    print("测试2: Embedding缓存功能")
    print("=" * 60)
    
    cache = EmbeddingCache(max_size=10, ttl=3600)
    
    # 测试单个缓存
    text1 = "这是测试文本1"
    embedding1 = [0.1, 0.2, 0.3, 0.4]
    cache.set_embedding(text1, embedding1, model="test-model")
    
    result = cache.get_embedding(text1, model="test-model")
    assert result == embedding1, "Embedding缓存失败"
    print("✅ 单文本Embedding缓存正常")
    
    # 测试批量缓存
    embeddings = {
        "文本1": [0.1, 0.2],
        "文本2": [0.3, 0.4],
        "文本3": [0.5, 0.6]
    }
    cache.set_batch_embeddings(embeddings, model="test-model")
    
    results = cache.get_batch_embeddings(["文本1", "文本2", "文本4"], model="test-model")
    assert results["文本1"] == [0.1, 0.2], "批量缓存获取失败"
    assert results["文本2"] == [0.3, 0.4], "批量缓存获取失败"
    assert results["文本4"] is None, "不存在的键应返回None"
    print("✅ 批量Embedding缓存正常")
    
    # 测试缓存信息
    info = cache.get_info()
    print(f"✅ 缓存信息: 大小={info['current_size']}/{info['max_size']}, "
          f"命中率={info['stats']['hit_rate']}")
    
    return True


def test_query_cache():
    """测试查询缓存"""
    print("\n" + "=" * 60)
    print("测试3: 查询缓存功能")
    print("=" * 60)
    
    cache = QueryCache(max_size=10, ttl=1800)
    
    # 测试查询缓存
    query = "如何学习Python"
    filters = {"category": "编程", "difficulty": "medium"}
    result = ["结果1", "结果2", "结果3"]
    
    cache.set_query_result(query, result, filters=filters, top_k=10)
    cached_result = cache.get_query_result(query, filters=filters, top_k=10)
    
    assert cached_result == result, "查询缓存失败"
    print("✅ 查询结果缓存正常")
    
    # 测试不同参数的缓存
    cached_result2 = cache.get_query_result(query, filters=filters, top_k=5)
    assert cached_result2 is None, "不同参数应该缓存未命中"
    print("✅ 查询参数区分正常")
    
    return True


def test_circuit_breaker():
    """测试熔断器"""
    print("\n" + "=" * 60)
    print("测试4: 熔断器功能")
    print("=" * 60)
    
    breaker = CircuitBreaker(failure_threshold=3, timeout=2, half_open_max_calls=2)
    
    # 测试成功调用
    def success_func():
        return "success"
    
    result = breaker.call(success_func)
    assert result == "success", "成功调用失败"
    assert breaker.get_state()["state"] == "CLOSED", "初始状态应该是CLOSED"
    print("✅ 熔断器成功调用正常")
    
    # 测试失败调用
    def fail_func():
        raise Exception("模拟失败")
    
    for i in range(3):
        try:
            breaker.call(fail_func)
        except Exception:
            pass
    
    state = breaker.get_state()
    assert state["state"] == "OPEN", "连续失败后应该打开熔断器"
    assert state["failures"] >= 3, "失败计数错误"
    print(f"✅ 熔断器打开正常 (失败次数: {state['failures']})")
    
    # 测试熔断器拒绝
    try:
        breaker.call(success_func)
        assert False, "熔断器打开时应该拒绝请求"
    except Exception as e:
        assert "熔断器已打开" in str(e), "错误消息不正确"
    print("✅ 熔断器拒绝请求正常")
    
    # 测试半开状态
    print("⏳ 等待熔断器超时...")
    time.sleep(2.1)
    
    result = breaker.call(success_func)
    assert result == "success", "半开状态调用失败"
    assert breaker.get_state()["state"] == "CLOSED", "成功后应该关闭熔断器"
    print("✅ 熔断器半开和恢复正常")
    
    return True


def test_embedding_service():
    """测试Embedding服务"""
    print("\n" + "=" * 60)
    print("测试5: Embedding服务功能")
    print("=" * 60)
    
    # 使用模拟配置
    from unittest.mock import Mock
    
    config = Mock(spec=Config)
    
    def mock_get(key, default=None):
        config_values = {
            "EMBEDDING_API_KEY": "test-key",
            "EMBEDDING_API_ENDPOINT": "https://api.openai.com/v1/embeddings",
            "EMBEDDING_MODEL": "text-embedding-3-small",
            "embedding.api_endpoint": "https://api.openai.com/v1/embeddings",
            "embedding.model": "text-embedding-3-small",
            "embedding.timeout": 30,
            "embedding.max_retries": 3,
            "embedding.retry_delay": 1,
            "embedding.batch_size": 20,
            "embedding.max_batch_size": 50,
            "cache.embedding_cache.enabled": True,
            "cache.embedding_cache.max_size": 100,
            "cache.embedding_cache.ttl": 3600,
            "embedding.circuit_breaker": {
                "failure_threshold": 5,
                "timeout": 60,
                "half_open_max_calls": 3
            }
        }
        return config_values.get(key, default)
    
    config.get = mock_get
    
    # 创建服务
    service = EmbeddingService(config_manager=config)
    
    # 检查配置
    assert service.model == "text-embedding-3-small", "模型配置错误"
    assert service.provider == EmbeddingProvider.OPENAI, "提供商检测错误"
    print(f"✅ 服务初始化正常 (提供商: {service.provider.value}, 模型: {service.model})")
    
    # 检查缓存
    assert service.cache is not None, "缓存未启用"
    cache_stats = service.get_cache_stats()
    assert cache_stats is not None, "缓存统计获取失败"
    print(f"✅ 缓存功能正常 (最大容量: {cache_stats['max_size']})")
    
    # 检查熔断器
    cb_state = service.get_circuit_breaker_state()
    assert cb_state["state"] == "CLOSED", "熔断器初始状态错误"
    print(f"✅ 熔断器功能正常 (状态: {cb_state['state']})")
    
    # 测试缓存清空
    service.clear_cache()
    print("✅ 缓存清空功能正常")
    
    # 测试熔断器重置
    service.reset_circuit_breaker()
    print("✅ 熔断器重置功能正常")
    
    return True


def main():
    """运行所有验证测试"""
    print("\n" + "=" * 60)
    print("阶段3验证 - Embedding API集成")
    print("=" * 60)
    
    tests = [
        ("LRU缓存", test_lru_cache),
        ("Embedding缓存", test_embedding_cache),
        ("查询缓存", test_query_cache),
        ("熔断器", test_circuit_breaker),
        ("Embedding服务", test_embedding_service),
    ]
    
    passed = 0
    failed = 0
    
    for name, test_func in tests:
        try:
            if test_func():
                passed += 1
            else:
                failed += 1
                print(f"❌ {name}测试失败")
        except Exception as e:
            failed += 1
            print(f"❌ {name}测试异常: {e}")
            import traceback
            traceback.print_exc()
    
    print("\n" + "=" * 60)
    print("验证结果汇总")
    print("=" * 60)
    print(f"✅ 通过: {passed}/{len(tests)}")
    print(f"❌ 失败: {failed}/{len(tests)}")
    
    if failed == 0:
        print("\n🎉 阶段3所有功能验证通过！")
        print("\n已完成功能：")
        print("  ✅ LRU缓存实现（线程安全、TTL支持、统计功能）")
        print("  ✅ Embedding专用缓存")
        print("  ✅ 查询结果缓存")
        print("  ✅ 熔断器实现（防雪崩、自动恢复）")
        print("  ✅ Embedding服务框架")
        print("  ✅ 多API提供商支持")
        print("  ✅ 重试机制（指数退避）")
        print("\n下一步：阶段4 - 题目管理服务")
        return 0
    else:
        print("\n⚠️  存在失败的测试，请检查")
        return 1


if __name__ == "__main__":
    exit(main())
