"""
Embedding服务测试模块

测试Embedding API集成功能：
- API调用
- 批量向量化
- 缓存机制
- 重试机制
- 熔断器
"""

import os
import time
from typing import List
from unittest.mock import Mock, patch, MagicMock

import pytest

from src.services.embedding_service import (
    EmbeddingService,
    EmbeddingProvider,
    CircuitBreaker
)
from src.utils.cache import EmbeddingCache, LRUCache
from src.core.config import Config


class TestLRUCache:
    """Tests for the generic LRUCache: get/set, TTL expiry, eviction order and stats."""

    def test_cache_set_get(self):
        """A stored key is returned; a missing key yields None or the caller's default."""
        lru = LRUCache(max_size=10)
        lru.set("key1", "value1")

        stored = lru.get("key1")
        assert stored == "value1"
        assert lru.get("key2") is None
        assert lru.get("key2", "default") == "default"

    def test_cache_ttl(self):
        """An entry is readable right away but disappears once its TTL has elapsed."""
        lru = LRUCache(max_size=10, ttl=1)
        lru.set("key1", "value1")

        # Still fresh immediately after insertion.
        assert lru.get("key1") == "value1"

        # Sleep slightly past the 1-second TTL so the entry is stale.
        time.sleep(1.1)
        expired = lru.get("key1")
        assert expired is None

    def test_cache_eviction(self):
        """Inserting beyond max_size evicts the oldest, never-read entry."""
        lru = LRUCache(max_size=2)
        for key, value in (("key1", "value1"), ("key2", "value2"), ("key3", "value3")):
            lru.set(key, value)

        # key1 was inserted first and never read, so it is the one pushed out.
        assert lru.get("key1") is None
        assert lru.get("key2") == "value2"
        assert lru.get("key3") == "value3"

    def test_cache_lru_order(self):
        """Reading a key refreshes its recency, so the untouched key is evicted instead."""
        lru = LRUCache(max_size=2)
        lru.set("key1", "value1")
        lru.set("key2", "value2")

        # Touch key1 so that key2 becomes the least recently used entry.
        lru.get("key1")

        # A third insert must therefore evict key2, not key1.
        lru.set("key3", "value3")

        first = lru.get("key1")
        second = lru.get("key2")
        third = lru.get("key3")
        assert first == "value1"
        assert second is None
        assert third == "value3"

    def test_cache_stats(self):
        """One hit plus one miss gives two requests and a 50% hit rate."""
        lru = LRUCache(max_size=10)
        lru.set("key1", "value1")

        lru.get("key1")  # hit
        lru.get("key2")  # miss

        snapshot = lru.get_stats()
        assert (snapshot.hits, snapshot.misses) == (1, 1)
        assert snapshot.total_requests == 2
        assert snapshot.hit_rate == 0.5


class TestEmbeddingCache:
    """Tests for EmbeddingCache's model-scoped single and batch lookups."""

    def test_embedding_cache(self):
        """An embedding stored for (text, model) is returned for the same pair."""
        store = EmbeddingCache(max_size=10)
        sample = "测试文本"
        vector = [0.1, 0.2, 0.3]

        store.set_embedding(sample, vector, model="test-model")

        assert store.get_embedding(sample, model="test-model") == vector

    def test_batch_embeddings(self):
        """A batch lookup returns cached vectors and None for texts never stored."""
        store = EmbeddingCache(max_size=10)
        store.set_batch_embeddings(
            {"text1": [0.1, 0.2], "text2": [0.3, 0.4]},
            model="test-model",
        )

        found = store.get_batch_embeddings(
            ["text1", "text2", "text3"], model="test-model"
        )
        assert found["text1"] == [0.1, 0.2]
        assert found["text2"] == [0.3, 0.4]
        # text3 was never cached, so its slot is reported as None.
        assert found["text3"] is None


class TestCircuitBreaker:
    """Tests for CircuitBreaker state transitions: CLOSED -> OPEN -> trial call -> CLOSED."""

    def test_circuit_breaker_success(self):
        """A successful call returns its result and leaves the breaker CLOSED."""
        breaker = CircuitBreaker(failure_threshold=3)

        def success_func():
            return "success"

        result = breaker.call(success_func)
        assert result == "success"
        assert breaker.get_state()["state"] == "CLOSED"

    def test_circuit_breaker_failure(self):
        """Reaching failure_threshold opens the breaker; later calls are rejected."""
        breaker = CircuitBreaker(failure_threshold=3)
        attempts = []

        def fail_func():
            attempts.append(1)
            raise Exception("error")

        # Match on the wrapped function's own message: a bare
        # pytest.raises(Exception) would also accept an early "breaker open"
        # rejection, hiding a breaker that trips too soon.
        for _ in range(3):
            with pytest.raises(Exception, match="error"):
                breaker.call(fail_func)

        # The threshold has been reached, so the breaker must now be open.
        assert breaker.get_state()["state"] == "OPEN"

        # Further calls are rejected by the breaker itself...
        with pytest.raises(Exception, match="熔断器已打开"):
            breaker.call(fail_func)
        # ...without the wrapped function being invoked a fourth time.
        assert len(attempts) == 3

    def test_circuit_breaker_half_open(self):
        """After `timeout` seconds an OPEN breaker admits a trial call; success closes it."""
        breaker = CircuitBreaker(failure_threshold=2, timeout=1)

        def fail_func():
            raise Exception("error")

        # Trip the breaker with real failures (see the match rationale above).
        for _ in range(2):
            with pytest.raises(Exception, match="error"):
                breaker.call(fail_func)

        assert breaker.get_state()["state"] == "OPEN"

        # Wait past the 1-second timeout so the next call is let through as a trial.
        time.sleep(1.1)

        def success_func():
            return "success"

        result = breaker.call(success_func)
        assert result == "success"
        assert breaker.get_state()["state"] == "CLOSED"


class TestEmbeddingService:
    """Unit tests for EmbeddingService with the HTTP layer (`requests.post`) patched out."""

    @pytest.fixture
    def mock_config(self):
        """Return a Config mock whose `get` is answered by `_mock_config_get`."""
        config = Mock(spec=Config)
        config.get = Mock(side_effect=self._mock_config_get)
        return config

    def _mock_config_get(self, key, default=None):
        """Return the canned value for *key*, or *default* when it is not listed.

        Both env-style (``EMBEDDING_*``) and dotted keys are provided;
        presumably the service may read either form — TODO confirm.
        The dict is rebuilt on every call so that a nested value mutated by
        the service cannot leak into later lookups or tests.
        """
        config_values = {
            "EMBEDDING_API_KEY": "test-key",
            "EMBEDDING_API_ENDPOINT": "https://api.openai.com/v1/embeddings",
            "EMBEDDING_MODEL": "text-embedding-3-small",
            "embedding.api_key": "test-key",
            "embedding.api_endpoint": "https://api.openai.com/v1/embeddings",
            "embedding.model": "text-embedding-3-small",
            "embedding.timeout": 30,
            "embedding.max_retries": 3,
            "embedding.retry_delay": 1,
            "embedding.batch_size": 20,
            "embedding.max_batch_size": 50,
            "cache.embedding_cache.enabled": True,
            "cache.embedding_cache.max_size": 100,
            "cache.embedding_cache.ttl": 3600,
            "embedding.circuit_breaker": {
                "failure_threshold": 5,
                "timeout": 60,
                "half_open_max_calls": 3
            }
        }
        return config_values.get(key, default)

    @staticmethod
    def _mock_response(*embeddings):
        """Build a fake HTTP response whose JSON body lists *embeddings* in order.

        ``raise_for_status`` is a no-op Mock, i.e. the response looks successful.
        """
        response = Mock()
        response.json.return_value = {
            "data": [{"embedding": vector} for vector in embeddings]
        }
        response.raise_for_status = Mock()
        return response

    @patch('requests.post')
    def test_embed_text(self, mock_post, mock_config):
        """A single text is embedded through the (mocked) API."""
        mock_post.return_value = self._mock_response([0.1, 0.2, 0.3])

        service = EmbeddingService(config_manager=mock_config)
        embedding = service.embed_text("测试文本")

        assert embedding == [0.1, 0.2, 0.3]
        assert mock_post.called

    @patch('requests.post')
    def test_embed_batch(self, mock_post, mock_config):
        """A batch of texts yields one embedding per text, in input order."""
        mock_post.return_value = self._mock_response([0.1, 0.2], [0.3, 0.4])

        service = EmbeddingService(config_manager=mock_config)
        texts = ["文本1", "文本2"]
        embeddings = service.embed_batch(texts)

        assert len(embeddings) == 2
        assert embeddings[0] == [0.1, 0.2]
        assert embeddings[1] == [0.3, 0.4]

    @patch('requests.post')
    def test_cache_hit(self, mock_post, mock_config):
        """Embedding the same text twice hits the API only once."""
        mock_post.return_value = self._mock_response([0.1, 0.2, 0.3])

        service = EmbeddingService(config_manager=mock_config)

        # First call populates the cache.
        embedding1 = service.embed_text("测试文本")
        assert mock_post.call_count == 1
        assert embedding1 == [0.1, 0.2, 0.3]

        # Second call must be served from the cache without another request.
        embedding2 = service.embed_text("测试文本")
        assert mock_post.call_count == 1
        assert embedding1 == embedding2

    @patch('requests.post')
    def test_retry_on_failure(self, mock_post, mock_config):
        """Two failing responses are retried; the third (successful) one is returned."""
        failing = Mock()
        failing.raise_for_status.side_effect = Exception("API Error")

        mock_post.side_effect = [
            failing,
            failing,
            self._mock_response([0.1, 0.2, 0.3]),
        ]

        service = EmbeddingService(config_manager=mock_config)
        service.retry_delay = 0.1  # keep the test fast

        # Bypass the cache so every attempt really goes to the mocked API.
        embedding = service.embed_text("测试文本", use_cache=False)
        assert embedding == [0.1, 0.2, 0.3]
        assert mock_post.call_count == 3

    def test_provider_detection(self, mock_config):
        """An api.openai.com endpoint is detected as the OpenAI provider."""
        service = EmbeddingService(config_manager=mock_config)
        assert service.provider == EmbeddingProvider.OPENAI

    def test_cache_stats(self, mock_config):
        """Cache stats expose size limits, current size and hit/miss statistics."""
        service = EmbeddingService(config_manager=mock_config)
        stats = service.get_cache_stats()

        assert stats is not None
        assert "max_size" in stats
        assert "current_size" in stats
        assert "stats" in stats


@pytest.mark.skipif(
    # Opt-in: these tests hit the real API, so they are skipped unless
    # RUN_EMBEDDING_INTEGRATION_TESTS is set (to a non-empty value) in the
    # environment, instead of requiring a source edit to enable them.
    not os.environ.get("RUN_EMBEDDING_INTEGRATION_TESTS"),
    reason="需要真实的API密钥和网络连接"
)
class TestEmbeddingServiceIntegration:
    """Integration tests against the real embedding API (needs credentials and network)."""

    def test_real_api_call(self):
        """A single real request returns a non-empty list of floats."""
        # Relies on a real API key being available to the default configuration.
        service = EmbeddingService()

        text = "这是一个测试文本"
        embedding = service.embed_text(text)

        assert isinstance(embedding, list)
        assert len(embedding) > 0
        assert all(isinstance(x, float) for x in embedding)

    def test_real_batch_call(self):
        """A real batch request returns one list-valued embedding per input text."""
        service = EmbeddingService()

        texts = ["文本1", "文本2", "文本3"]
        embeddings = service.embed_batch(texts)

        assert len(embeddings) == 3
        assert all(isinstance(e, list) for e in embeddings)


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports test failures to the shell or CI (the status would otherwise always be 0).
    raise SystemExit(pytest.main([__file__, "-v"]))
