"""
Embedding API 服务模块

提供文本向量化功能，支持：
- 多种Embedding API（OpenAI、通义千问、智谱AI等）
- 批量向量化
- 重试和熔断机制
- LRU缓存
- 错误处理
"""

import time
import os
import requests
import random
from typing import List, Dict, Any, Optional, Tuple
from enum import Enum
from dataclasses import dataclass
import threading

from ..core.config import Config
from ..core.logger import get_logger
from ..utils.cache import EmbeddingCache


# Module-level logger shared by every class in this file; created through the
# project's get_logger helper and keyed by this module's import name.
logger = get_logger(__name__)


class EmbeddingProvider(Enum):
    """Embedding API providers that the service distinguishes between."""
    OPENAI = "openai"
    QWEN = "qwen"  # Tongyi Qianwen
    ZHIPU = "zhipu"  # Zhipu AI
    CUSTOM = "custom"


@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping record held by a circuit breaker."""
    # Failures counted since the counter was last reset.
    failures: int = 0
    # Timestamp of the most recent failure. The default is a real float so the
    # value always matches its annotation (0.0 == 0, so comparisons are
    # unaffected).
    last_failure_time: float = 0.0
    # One of "CLOSED", "OPEN" or "HALF_OPEN".
    state: str = "CLOSED"
    # Trial calls admitted so far during the current half-open phase.
    half_open_calls: int = 0


class CircuitBreakerOpenError(Exception):
    """
    Raised by ``CircuitBreaker.call`` when a request is rejected without
    running the wrapped function.

    It subclasses ``Exception`` so existing ``except Exception`` handlers keep
    working, while new callers can tell a rejection apart from a genuine
    failure of the wrapped function.
    """


class CircuitBreaker:
    """
    Circuit breaker that stops calling a dependency after repeated failures.

    The state is kept as a string in ``self._state.state``:

    - ``"CLOSED"``: calls pass through and failures are counted; a success
      resets the counter.
    - ``"OPEN"``: calls are rejected until ``timeout`` seconds have passed
      since the last recorded failure.
    - ``"HALF_OPEN"``: at most ``half_open_max_calls`` trial calls are
      admitted; the first success closes the breaker, any failure reopens it.
    """

    def __init__(self, failure_threshold: int = 5, timeout: int = 60,
                 half_open_max_calls: int = 3):
        """
        Create a breaker in the closed state.

        Args:
            failure_threshold: Failure count at which the breaker opens.
            timeout: Seconds the breaker stays open before trial calls are
                allowed again.
            half_open_max_calls: Maximum number of calls admitted while the
                breaker is half-open.
        """
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls
        self._state = CircuitBreakerState()
        self._lock = threading.Lock()

    def call(self, func, *args, **kwargs):
        """
        Run ``func(*args, **kwargs)`` under the protection of the breaker.

        Args:
            func: Callable to execute.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: The breaker is open, or it is half-open
                and the trial-call quota is used up. ``func`` is not invoked.
            Exception: Any exception raised by ``func``; it is recorded as a
                failure and re-raised unchanged.
        """
        # Admission is decided under the lock; the wrapped call itself runs
        # outside it so slow calls do not serialize each other.
        with self._lock:
            if self._state.state == "OPEN":
                elapsed = time.time() - self._state.last_failure_time
                if elapsed < self.timeout:
                    raise CircuitBreakerOpenError("熔断器已打开，拒绝请求")
                # The open period is over: start admitting trial calls.
                self._state.state = "HALF_OPEN"
                self._state.half_open_calls = 0
                logger.info("熔断器进入半开状态")

            if self._state.state == "HALF_OPEN":
                if self._state.half_open_calls >= self.half_open_max_calls:
                    raise CircuitBreakerOpenError("熔断器半开状态，请求次数已达上限")
                self._state.half_open_calls += 1

        try:
            result = func(*args, **kwargs)
        except Exception:
            self._on_failure()
            # Bare raise keeps the original exception object and traceback.
            raise

        # Success bookkeeping stays outside the try block so that an error in
        # it can never be mistaken for a failure of the protected call.
        self._on_success()
        return result

    def _on_success(self):
        """Record a successful call: reset the counter, close if half-open."""
        with self._lock:
            if self._state.state == "HALF_OPEN":
                self._state.state = "CLOSED"
                logger.info("熔断器已关闭")
            self._state.failures = 0

    def _on_failure(self):
        """Record a failed call and open the breaker when warranted."""
        with self._lock:
            self._state.failures += 1
            self._state.last_failure_time = time.time()

            if self._state.state == "HALF_OPEN":
                # A failed trial call sends the breaker straight back to open.
                self._state.state = "OPEN"
                logger.warning("熔断器重新打开")
            elif self._state.failures >= self.failure_threshold:
                self._state.state = "OPEN"
                logger.warning(f"熔断器已打开，连续失败次数: {self._state.failures}")

    def get_state(self) -> Dict[str, Any]:
        """
        Return a snapshot of the breaker's bookkeeping.

        Returns:
            Dict with the keys ``state``, ``failures``, ``last_failure_time``
            and ``half_open_calls``.
        """
        with self._lock:
            return {
                "state": self._state.state,
                "failures": self._state.failures,
                "last_failure_time": self._state.last_failure_time,
                "half_open_calls": self._state.half_open_calls
            }

    def reset(self):
        """Discard all bookkeeping and return to a fresh closed state."""
        with self._lock:
            self._state = CircuitBreakerState()
            logger.info("熔断器已重置")


class EmbeddingService:
    """
    Text embedding service.

    Exposes one interface for turning text into vectors over an HTTP
    embedding API. The provider is guessed from the endpoint URL; requests go
    through a circuit breaker and are retried with exponential backoff;
    results can be served from an optional cache; a mock mode returns random
    vectors without touching the network.
    """

    def __init__(self, config_manager: Optional[Config] = None):
        """
        Build the service from configuration and environment variables.

        Args:
            config_manager: Configuration source; a new ``Config()`` is
                created when omitted.
        """
        self.config = config_manager or Config()

        # Connection settings: a non-empty environment variable wins over the
        # configuration file.
        self.api_key = (
            os.getenv("EMBEDDING_API_KEY") or
            self.config.get("embedding.api_key", "")
        )
        self.api_endpoint = (
            os.getenv("EMBEDDING_API_ENDPOINT") or
            self.config.get("embedding.api_endpoint", "")
        )
        self.model = (
            os.getenv("EMBEDDING_MODEL") or
            self.config.get("embedding.model", "text-embedding-v2")
        )
        self.timeout = self.config.get("embedding.timeout", 30)
        self.max_retries = self.config.get("embedding.max_retries", 3)
        self.retry_delay = self.config.get("embedding.retry_delay", 1)
        self.batch_size = self.config.get("embedding.batch_size", 20)
        self.max_batch_size = self.config.get("embedding.max_batch_size", 50)

        # Optional embedding cache; ``None`` means caching is disabled.
        cache_enabled = self.config.get("cache.embedding_cache.enabled", True)
        cache_max_size = self.config.get("cache.embedding_cache.max_size", 1000)
        cache_ttl = self.config.get("cache.embedding_cache.ttl", 3600)

        self.cache = EmbeddingCache(
            max_size=cache_max_size,
            ttl=cache_ttl
        ) if cache_enabled else None

        # Circuit breaker wrapped around every API request.
        cb_config = self.config.get("embedding.circuit_breaker", {})
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=cb_config.get("failure_threshold", 5),
            timeout=cb_config.get("timeout", 60),
            half_open_max_calls=cb_config.get("half_open_max_calls", 3)
        )

        self.provider = self._detect_provider()

        # Mock mode is switched on with MOCK_EMBEDDING=true (case-insensitive).
        self.mock_mode = os.getenv("MOCK_EMBEDDING", "").lower() == "true"

        # Dimension of the vectors produced in mock mode.
        self.mock_embedding_dim = 1536

        if self.mock_mode:
            logger.warning("Embedding服务运行在模拟模式下，将返回随机向量（仅用于测试）")
        else:
            logger.info(f"Embedding服务初始化完成，提供商: {self.provider.value}, 模型: {self.model}")

    def _detect_provider(self) -> EmbeddingProvider:
        """
        Guess the provider from substrings of the configured endpoint URL.

        Returns:
            The matching ``EmbeddingProvider``; ``CUSTOM`` when the endpoint is
            empty or matches none of the known substrings.
        """
        endpoint = self.api_endpoint.lower() if self.api_endpoint else ""
        if "openai" in endpoint:
            return EmbeddingProvider.OPENAI
        if "dashscope" in endpoint or "aliyun" in endpoint:
            return EmbeddingProvider.QWEN
        if "zhipu" in endpoint:
            return EmbeddingProvider.ZHIPU
        return EmbeddingProvider.CUSTOM

    def _generate_mock_embedding(self) -> List[float]:
        """Return one random unit-length vector of ``mock_embedding_dim`` values (testing only)."""
        embedding = [random.gauss(0, 1) for _ in range(self.mock_embedding_dim)]
        # Scale to unit length.
        norm = sum(x * x for x in embedding) ** 0.5
        return [x / norm for x in embedding]

    def _generate_mock_embeddings(self, count: int) -> List[List[float]]:
        """Return ``count`` independent mock vectors (testing only)."""
        return [self._generate_mock_embedding() for _ in range(count)]

    def embed_text(self, text: str, use_cache: bool = True) -> List[float]:
        """
        Embed a single text.

        Args:
            text: Input text.
            use_cache: Whether to read from and write to the cache.

        Returns:
            The embedding vector.

        Raises:
            ValueError: ``text`` is empty or only whitespace.
            Exception: The API call failed.
        """
        if not text or not text.strip():
            raise ValueError("文本不能为空")

        if self.mock_mode:
            logger.debug(f"模拟模式：生成随机向量，文本长度: {len(text)}")
            return self._generate_mock_embedding()

        # Compare with None instead of relying on truthiness: a cache object
        # that implements __len__ would be falsy while empty and could then
        # never receive its first entry.
        caching = use_cache and self.cache is not None

        if caching:
            cached = self.cache.get_embedding(text, self.model)
            if cached is not None:
                logger.debug(f"从缓存获取向量，文本长度: {len(text)}")
                return cached

        # The cache was already consulted above, so the batch call skips it.
        embedding = self.embed_batch([text], use_cache=False)[0]

        if caching:
            self.cache.set_embedding(text, embedding, self.model)

        return embedding

    def embed_batch(self, texts: List[str], use_cache: bool = True) -> List[List[float]]:
        """
        Embed a list of texts.

        Args:
            texts: Input texts.
            use_cache: Whether to read from and write to the cache.

        Returns:
            One vector per input text, in input order.

        Raises:
            ValueError: ``texts`` is empty or longer than ``max_batch_size``.
            Exception: The API call failed.
        """
        if not texts:
            raise ValueError("文本列表不能为空")

        if len(texts) > self.max_batch_size:
            raise ValueError(f"批次大小超过限制: {len(texts)} > {self.max_batch_size}")

        if self.mock_mode:
            logger.debug(f"模拟模式：生成 {len(texts)} 个随机向量")
            return self._generate_mock_embeddings(len(texts))

        # See embed_text for why this is an identity check.
        caching = use_cache and self.cache is not None

        # Result slots are filled from the cache first; the remaining
        # positions are remembered so API results land in the right place.
        uncached_texts = []
        uncached_indices = []
        results: List[List[float]] = [[] for _ in texts]

        if caching:
            for i, text in enumerate(texts):
                cached = self.cache.get_embedding(text, self.model)
                if cached is not None:
                    results[i] = cached
                else:
                    uncached_texts.append(text)
                    uncached_indices.append(i)

            if uncached_texts:
                logger.debug(f"缓存命中: {len(texts) - len(uncached_texts)}/{len(texts)}")
        else:
            uncached_texts = texts
            uncached_indices = list(range(len(texts)))

        # Everything came from the cache.
        if not uncached_texts:
            return results  # type: ignore

        try:
            new_embeddings = self._call_api_with_retry(uncached_texts)

            for i, embedding in zip(uncached_indices, new_embeddings):
                results[i] = embedding
                if caching:
                    self.cache.set_embedding(texts[i], embedding, self.model)

            return results

        except Exception as e:
            logger.error(f"批量向量化失败: {e}")
            raise

    def _call_api_with_retry(self, texts: List[str]) -> List[List[float]]:
        """
        Call the API through the circuit breaker, retrying with exponential
        backoff.

        Args:
            texts: Texts to embed.

        Returns:
            One vector per text.

        Raises:
            Exception: Every attempt failed; the last error is chained as
                ``__cause__``.
        """
        # Always make at least one attempt, even when max_retries is
        # configured as 0 or a negative number.
        attempts = max(1, self.max_retries)
        last_error: Optional[Exception] = None

        for attempt in range(attempts):
            try:
                return self.circuit_breaker.call(self._call_api, texts)

            except Exception as e:
                last_error = e
                logger.warning(f"API调用失败 (尝试 {attempt + 1}/{attempts}): {e}")

                if attempt < attempts - 1:
                    # Exponential backoff: retry_delay, then 2x, 4x, ...
                    delay = self.retry_delay * (2 ** attempt)
                    logger.info(f"等待 {delay} 秒后重试...")
                    time.sleep(delay)

        raise Exception(f"API调用失败，已重试 {attempts} 次: {last_error}") from last_error

    def _call_api(self, texts: List[str]) -> List[List[float]]:
        """
        Perform one HTTP request against the embedding endpoint.

        Args:
            texts: Texts to embed.

        Returns:
            One vector per text.

        Raises:
            Exception: The request failed, the response could not be parsed,
                or the number of returned vectors differs from ``len(texts)``.
        """
        headers, payload = self._build_request(texts)

        start_time = time.time()
        try:
            response = requests.post(
                self.api_endpoint,
                headers=headers,
                json=payload,
                timeout=self.timeout
            )
            response.raise_for_status()

        except requests.exceptions.RequestException as e:
            logger.error(f"API请求失败: {e}")
            raise Exception(f"API请求失败: {e}") from e

        try:
            embeddings = self._parse_response(response.json())
            # A short response would otherwise leave empty placeholder vectors
            # in embed_batch's result, because zip() stops at the shorter
            # sequence without complaining.
            self._ensure_embedding_count(embeddings, len(texts))

        except Exception as e:
            logger.error(f"响应解析失败: {e}")
            raise Exception(f"响应解析失败: {e}") from e

        duration = time.time() - start_time
        logger.info(f"API调用成功，耗时: {duration:.2f}秒，文本数: {len(texts)}")
        return embeddings

    @staticmethod
    def _ensure_embedding_count(embeddings, expected: int) -> None:
        """Raise ``Exception`` unless ``embeddings`` holds exactly ``expected`` items."""
        actual = len(embeddings)
        if actual != expected:
            raise Exception(f"返回的向量数量与输入文本数量不一致: {actual} != {expected}")

    def _build_request(self, texts: List[str]) -> Tuple[Dict[str, str], Dict[str, Any]]:
        """
        Build the headers and JSON payload for the detected provider.

        Args:
            texts: Texts to embed.

        Returns:
            ``(headers, payload)`` tuple.
        """
        # Every provider handled here uses bearer-token authentication.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        if self.provider == EmbeddingProvider.QWEN:
            # This provider's payload nests the texts one level deeper.
            model_input: Any = {"texts": texts}
        else:
            # OPENAI, ZHIPU and CUSTOM all send the list of texts directly.
            model_input = texts

        payload = {
            "model": self.model,
            "input": model_input
        }
        return headers, payload

    def _parse_response(self, response_data: Dict[str, Any]) -> List[List[float]]:
        """
        Extract the embedding vectors from a decoded JSON response.

        Args:
            response_data: Decoded JSON body.

        Returns:
            List of vectors in the order the response lists them.

        Raises:
            Exception: The response does not have a recognised structure.
        """
        try:
            if self.provider == EmbeddingProvider.QWEN:
                return [item["embedding"] for item in response_data["output"]["embeddings"]]

            if self.provider in (EmbeddingProvider.OPENAI, EmbeddingProvider.ZHIPU):
                return [item["embedding"] for item in response_data["data"]]

            # Unknown provider: accept either of the two generic layouts.
            if "data" in response_data:
                return [item["embedding"] for item in response_data["data"]]
            if "embeddings" in response_data:
                return response_data["embeddings"]
            raise Exception("未知的响应格式")

        except (KeyError, TypeError) as e:
            raise Exception(f"响应格式错误: {e}") from e

    def get_cache_stats(self) -> Optional[Dict[str, Any]]:
        """
        Return the cache's ``get_info()`` result.

        Returns:
            The statistics reported by the cache, or ``None`` when caching is
            disabled.
        """
        if self.cache is not None:
            return self.cache.get_info()
        return None

    def get_circuit_breaker_state(self) -> Dict[str, Any]:
        """
        Return the circuit breaker's current state.

        Returns:
            The dict produced by ``CircuitBreaker.get_state()``.
        """
        return self.circuit_breaker.get_state()

    def clear_cache(self):
        """Empty the embedding cache; does nothing when caching is disabled."""
        if self.cache is not None:
            self.cache.clear()
            logger.info("Embedding缓存已清空")

    def reset_circuit_breaker(self):
        """Reset the circuit breaker to its initial state."""
        self.circuit_breaker.reset()
