"""
日志管理模块

提供统一的日志记录功能，支持文件日志、控制台日志、日志轮转等特性。
支持 request_id 追踪，便于追踪完整的请求链路。
"""

import logging
import re
import sys
import threading
import uuid
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional

# Thread-local storage for the current request_id; read by RequestIdFilter
# and written/read/cleared by the Logger request-id helpers.
# NOTE(review): threading.local is per-thread, so work interleaved on one
# thread (e.g. asyncio tasks) would share a single value — confirm how the
# server dispatches requests.
_thread_local = threading.local()


class RequestIdFilter(logging.Filter):
    """Stamp each log record with the calling thread's request id.

    When no request id has been set for the thread, ``"N/A"`` is used.
    The filter never rejects a record.
    """

    def filter(self, record):
        current_id = getattr(_thread_local, "request_id", "N/A")
        setattr(record, "request_id", current_id)
        return True


class SensitiveDataFilter(logging.Filter):
    """Redact sensitive values (API keys, passwords, tokens, ...) in log records.

    The fully formatted message (``msg % args``) is scrubbed. Secrets passed
    as ``%``-style arguments are therefore redacted as well, and redacting a
    placeholder can no longer break the later ``msg % args`` formatting.
    """

    # Built-in (regex, replacement) pairs used when no patterns are supplied.
    _DEFAULT_PATTERNS = (
        (r"api[_-]?key['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", "api_key=***"),
        (r"password['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", "password=***"),
        (r"token['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", "token=***"),
        (r"secret['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", "secret=***"),
        (r"authorization:\s*bearer\s+([^\s]+)", "authorization: bearer ***"),
        (r"sk-[a-zA-Z0-9]{32,}", "sk-***"),  # OpenAI API key format
    )

    def __init__(self, sensitive_patterns=None):
        """
        Args:
            sensitive_patterns: Iterable of ``(regex, replacement)`` pairs;
                ``None`` selects the built-in defaults. Every regex is
                compiled with ``re.IGNORECASE``.
        """
        super().__init__()
        if sensitive_patterns is None:
            sensitive_patterns = self._DEFAULT_PATTERNS
        self.patterns = [
            (re.compile(pattern, re.IGNORECASE), replacement)
            for pattern, replacement in sensitive_patterns
        ]

    def _scrub(self, text: str) -> str:
        """Apply every redaction pattern to ``text`` in order."""
        for pattern, replacement in self.patterns:
            text = pattern.sub(replacement, text)
        return text

    def filter(self, record):
        """Redact the record's message in place; never drops the record."""
        try:
            message = record.getMessage()
        except Exception:
            # msg and args do not fit together: scrub msg on its own and let
            # the handler's normal error reporting surface the mismatch.
            record.msg = self._scrub(str(record.msg))
            return True
        record.msg = self._scrub(message)
        # The args are already merged into msg. Clearing them prevents a
        # second formatting pass, which would fail once a "%s" was redacted.
        record.args = None
        return True


class Logger:
    """Logging manager built on a named :class:`logging.Logger`.

    On construction it configures the named logger from an optional config
    object:

    - the logger level and format;
    - a console handler on stderr;
    - a rotating log file and a rotating error-only log file;
    - the request-id and sensitive-data filters.

    Records are not propagated to ancestor loggers.
    """

    # Level names accepted in the config (matched case-insensitively).
    _LEVEL_MAP = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    # Format used when the config does not provide "logging.format".
    _DEFAULT_FORMAT = (
        "%(asctime)s - %(name)s - %(levelname)s - [%(request_id)s] - %(message)s"
    )

    # Rotation policy used when the config does not provide one.
    _DEFAULT_MAX_BYTES = 10485760  # 10 MB
    _DEFAULT_BACKUP_COUNT = 30

    def __init__(self, name: str = "questions_mcp_server", config=None):
        """
        Initialise the logging manager.

        Args:
            name: Name passed to ``logging.getLogger``.
            config: Optional config object queried via
                ``config.get(dotted_key, default)`` (e.g. ``"logging.level"``).
                When falsy, built-in defaults are used.
        """
        self.name = name
        self.config = config
        self.logger = logging.getLogger(name)
        self._setup_logger()

    # ------------------------------------------------------------------ #
    # Configuration helpers
    # ------------------------------------------------------------------ #
    def _cfg(self, key: str, default):
        """Return ``config.get(key, default)``, or ``default`` when no config is set."""
        if self.config:
            return self.config.get(key, default)
        return default

    @staticmethod
    def _parse_level(value, default: int = logging.INFO) -> int:
        """Convert a level name (case-insensitive) or a numeric level to an int.

        Unknown names and ``None`` map to ``default``.
        """
        if isinstance(value, int) and not isinstance(value, bool):
            return value
        return Logger._LEVEL_MAP.get(str(value).upper(), default)

    def _setup_logger(self):
        """(Re)configure ``self.logger`` from the config.

        Safe to call repeatedly for the same logger name. Handlers installed
        earlier are closed, which releases their files. Filters installed by
        an earlier call are removed, so they do not stack up.
        """
        self._reset_logger()

        self.logger.setLevel(self._get_log_level())

        self.logger.addFilter(RequestIdFilter())
        self.logger.addFilter(SensitiveDataFilter())

        formatter = logging.Formatter(self._get_log_format())

        if self._is_console_enabled():
            self.logger.addHandler(self._create_console_handler(formatter))

        if self._is_file_enabled():
            file_handler = self._create_file_handler(formatter)
            if file_handler:
                self.logger.addHandler(file_handler)

        if self._is_error_file_enabled():
            error_handler = self._create_error_file_handler(formatter)
            if error_handler:
                self.logger.addHandler(error_handler)

        # Do not pass records on to ancestor loggers' handlers.
        self.logger.propagate = False

    def _reset_logger(self):
        """Detach and close existing handlers and drop filters this class installs."""
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            try:
                handler.close()
            except OSError:
                # Best effort: a failing flush/close must not block reconfiguration.
                pass
        for old_filter in list(self.logger.filters):
            if isinstance(old_filter, (RequestIdFilter, SensitiveDataFilter)):
                self.logger.removeFilter(old_filter)

    def _get_log_level(self) -> int:
        """Logger level from "logging.level" (default INFO)."""
        return self._parse_level(self._cfg("logging.level", "INFO"))

    def _get_log_format(self) -> str:
        """Record format from "logging.format" (default includes the request id)."""
        return self._cfg("logging.format", self._DEFAULT_FORMAT)

    def _is_console_enabled(self) -> bool:
        """Whether console logging is enabled ("logging.console.enabled", default True)."""
        return self._cfg("logging.console.enabled", True)

    def _is_file_enabled(self) -> bool:
        """Whether file logging is enabled ("logging.file.enabled", default True)."""
        return self._cfg("logging.file.enabled", True)

    def _is_error_file_enabled(self) -> bool:
        """Whether error-file logging is enabled ("logging.error_file.enabled", default True)."""
        return self._cfg("logging.error_file.enabled", True)

    # ------------------------------------------------------------------ #
    # Handler factories
    # ------------------------------------------------------------------ #
    def _create_console_handler(self, formatter: logging.Formatter) -> logging.Handler:
        """Create the console handler.

        It writes to stderr, not stdout: stdout carries the MCP STDIO
        JSON-RPC messages and must not be polluted by log output. The level
        comes from "logging.console.level" (default INFO).
        """
        console_handler = logging.StreamHandler(sys.stderr)
        console_handler.setLevel(
            self._parse_level(self._cfg("logging.console.level", "INFO"))
        )
        console_handler.setFormatter(formatter)
        return console_handler

    def _create_rotating_handler(
        self,
        section: str,
        default_path: str,
        level: int,
        formatter: logging.Formatter,
        failure_message: str,
    ) -> Optional[logging.Handler]:
        """Build a RotatingFileHandler from the ``logging.<section>.*`` settings.

        Args:
            section: Config section, e.g. ``"file"`` or ``"error_file"``.
            default_path: Log path used when ``logging.<section>.path`` is unset.
            level: Handler level.
            formatter: Formatter attached to the handler.
            failure_message: Warning template with an ``{error}`` placeholder.

        Returns:
            The handler, or None if its directory or file could not be
            created. In that case a warning is written to stderr and logging
            continues without this handler.
        """
        try:
            log_path = self._cfg(f"logging.{section}.path", default_path)
            max_bytes = self._cfg(
                f"logging.{section}.max_bytes", self._DEFAULT_MAX_BYTES
            )
            backup_count = self._cfg(
                f"logging.{section}.backup_count", self._DEFAULT_BACKUP_COUNT
            )

            # Make sure the log directory exists.
            Path(log_path).parent.mkdir(parents=True, exist_ok=True)

            handler = RotatingFileHandler(
                log_path,
                maxBytes=max_bytes,
                backupCount=backup_count,
                encoding="utf-8",
            )
        except Exception as e:
            # stdout carries MCP JSON-RPC traffic, so the warning goes to stderr.
            print(failure_message.format(error=e), file=sys.stderr)
            return None
        handler.setLevel(level)
        handler.setFormatter(formatter)
        return handler

    def _create_file_handler(self, formatter: logging.Formatter) -> Optional[logging.Handler]:
        """Create the main rotating log-file handler (level DEBUG), or None on failure."""
        return self._create_rotating_handler(
            section="file",
            default_path="./logs/app.log",
            level=logging.DEBUG,
            formatter=formatter,
            failure_message="[WARNING] 无法创建文件日志: {error}，将只输出到控制台",
        )

    def _create_error_file_handler(
        self, formatter: logging.Formatter
    ) -> Optional[logging.Handler]:
        """Create the rotating error-only file handler (level ERROR), or None on failure."""
        return self._create_rotating_handler(
            section="error_file",
            default_path="./logs/error.log",
            level=logging.ERROR,
            formatter=formatter,
            failure_message="[WARNING] 无法创建错误日志文件: {error}，错误将只输出到控制台",
        )

    # ------------------------------------------------------------------ #
    # Logging API
    # ------------------------------------------------------------------ #
    @staticmethod
    def _resolve_exc_info(exc_info):
        """Turn ``exc_info=True`` into False when no exception is being handled.

        ``logging`` would otherwise attach ``(None, None, None)``, which is
        rendered as a meaningless "NoneType: None" traceback line.
        """
        if exc_info is True and sys.exc_info()[0] is None:
            return False
        return exc_info

    def debug(self, message: str, **kwargs):
        """Log ``message`` at DEBUG level."""
        self.logger.debug(message, **kwargs)

    def info(self, message: str, **kwargs):
        """Log ``message`` at INFO level."""
        self.logger.info(message, **kwargs)

    def warning(self, message: str, **kwargs):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message, **kwargs)

    def error(self, message: str, exc_info: bool = True, **kwargs):
        """Log ``message`` at ERROR level.

        By default the traceback is attached when called while an exception
        is being handled.
        """
        self.logger.error(
            message, exc_info=self._resolve_exc_info(exc_info), **kwargs
        )

    def critical(self, message: str, exc_info: bool = True, **kwargs):
        """Log ``message`` at CRITICAL level.

        By default the traceback is attached when called while an exception
        is being handled.
        """
        self.logger.critical(
            message, exc_info=self._resolve_exc_info(exc_info), **kwargs
        )

    def exception(self, message: str, **kwargs):
        """Log ``message`` at ERROR level with the current exception's traceback."""
        self.logger.exception(message, **kwargs)

    # ------------------------------------------------------------------ #
    # Request-id helpers (thread-local)
    # ------------------------------------------------------------------ #
    @staticmethod
    def set_request_id(request_id: Optional[str] = None) -> str:
        """
        Set the calling thread's request_id.

        Args:
            request_id: Request id to store; a new UUID4 string is generated
                when None.

        Returns:
            The request id that was stored.
        """
        if request_id is None:
            request_id = str(uuid.uuid4())
        _thread_local.request_id = request_id
        return request_id

    @staticmethod
    def get_request_id() -> str:
        """Return the calling thread's request_id, or "N/A" when none is set."""
        return getattr(_thread_local, "request_id", "N/A")

    @staticmethod
    def clear_request_id():
        """Remove the calling thread's request_id; a no-op when none is set."""
        if hasattr(_thread_local, "request_id"):
            delattr(_thread_local, "request_id")


# Shared Logger instance: created lazily by get_logger() and replaced
# unconditionally by setup_logger().
_logger_instance: Optional[Logger] = None


def get_logger(name: str = "questions_mcp_server", config=None) -> Logger:
    """
    Return the shared Logger, creating it on first use (singleton).

    Only the arguments of the call that creates the instance take effect.
    Later calls return the existing instance and ignore ``name`` and
    ``config``. ``setup_logger`` replaces the instance with a new one.

    Args:
        name: Logger name used if the instance must be created.
        config: Config object used if the instance must be created.

    Returns:
        The shared Logger instance.
    """
    global _logger_instance
    if _logger_instance is not None:
        return _logger_instance
    _logger_instance = Logger(name, config)
    return _logger_instance


def setup_logger(config) -> Logger:
    """
    Build a new shared Logger named "questions_mcp_server" from ``config``.

    Any previously shared instance is replaced. Subsequent ``get_logger()``
    calls return the new instance.

    Args:
        config: Config object passed to the new Logger.

    Returns:
        The newly created Logger instance.
    """
    global _logger_instance
    fresh = Logger("questions_mcp_server", config)
    _logger_instance = fresh
    return fresh


# 便捷函数
def debug(message: str, **kwargs):
    """Log ``message`` at DEBUG level through the shared Logger."""
    shared = get_logger()
    shared.debug(message, **kwargs)


def info(message: str, **kwargs):
    """Log ``message`` at INFO level through the shared Logger."""
    shared = get_logger()
    shared.info(message, **kwargs)


def warning(message: str, **kwargs):
    """Log ``message`` at WARNING level through the shared Logger."""
    shared = get_logger()
    shared.warning(message, **kwargs)


def error(message: str, exc_info: bool = True, **kwargs):
    """Log ``message`` at ERROR level through the shared Logger; ``exc_info`` is forwarded as given."""
    shared = get_logger()
    shared.error(message, exc_info=exc_info, **kwargs)


def critical(message: str, exc_info: bool = True, **kwargs):
    """Log ``message`` at CRITICAL level through the shared Logger; ``exc_info`` is forwarded as given."""
    shared = get_logger()
    shared.critical(message, exc_info=exc_info, **kwargs)


def exception(message: str, **kwargs):
    """Log ``message`` with the current exception's traceback through the shared Logger."""
    shared = get_logger()
    shared.exception(message, **kwargs)


def set_request_id(request_id: Optional[str] = None) -> str:
    """Store ``request_id`` for the calling thread (a new UUID4 string when None) and return it."""
    assigned = Logger.set_request_id(request_id)
    return assigned


def get_request_id() -> str:
    """Return the calling thread's request_id, or "N/A" when none is set."""
    current = Logger.get_request_id()
    return current


def clear_request_id():
    """Forget the calling thread's request_id; does nothing when none is set."""
    Logger.clear_request_id()
    return None
