from typing import Dict, List, Optional
from uuid import UUID

from tortoise.exceptions import DoesNotExist
from tortoise.transactions import atomic

from algo_flow.app.algo.models import ModelInfo, Project
from algo_flow.cores.constant.algo import ModelFormat, ModelStatus, ModelType
from algo_flow.cores.exceptions import ResourceConflictError, ResourceNotFoundError


async def create_model(
    project_id: UUID,
    name: str,
    version: str,
    type: ModelType,
    format: ModelFormat,
    config: Dict,
    storage_path: Optional[str] = None,
    description: Optional[str] = None,
    status: ModelStatus = ModelStatus.DRAFT,
    metrics: Optional[Dict] = None,
) -> ModelInfo:
    """
    创建新模型

    Args:
        project_id: 所属项目ID
        name: 模型名称
        version: 模型版本号
        type: 模型类型
        format: 模型格式
        config: 模型配置信息
        storage_path: 模型文件存储路径
        description: 模型描述
        status: 模型状态
        metrics: 模型评估指标

    Returns:
        ModelInfo: 创建的模型实例

    Raises:
        ResourceNotFoundError: 当项目不存在时抛出
        ResourceConflictError: 当模型名称在项目中已存在时抛出
    """
    # 检查项目是否存在
    try:
        project = await Project.get(id=project_id)
    except DoesNotExist:
        raise ResourceNotFoundError(f"项目 {project_id} 不存在")

    # 检查模型名称在项目中是否已存在
    existing = await ModelInfo.filter(project_id=project_id, name=name).first()
    if existing:
        raise ResourceConflictError(f"模型名称 '{name}' 在项目中已存在")

    # 创建新模型
    model = await ModelInfo.create(
        project=project,
        name=name,
        description=description,
        version=version,
        type=type,
        format=format,
        status=status,
        storage_path=storage_path,
        config=config,
        metrics=metrics,
    )

    return model


async def get_model(model_id: UUID) -> ModelInfo:
    """
    获取模型详情

    Args:
        model_id: 模型ID

    Returns:
        ModelInfo: 模型实例

    Raises:
        ResourceNotFoundError: 当模型不存在时抛出
    """
    try:
        model = await ModelInfo.get(id=model_id)
    except DoesNotExist:
        raise ResourceNotFoundError(f"模型 {model_id} 不存在")

    return model


async def list_models(
    project_id: Optional[UUID] = None,
    status: Optional[ModelStatus] = None,
    type: Optional[ModelType] = None,
    offset: int = 0,
    limit: int = 10,
) -> List[ModelInfo]:
    """
    获取模型列表

    Args:
        project_id: 可选的项目ID过滤
        status: 可选的模型状态过滤
        type: 可选的模型类型过滤
        offset: 分页偏移量
        limit: 分页大小

    Returns:
        List[ModelInfo]: 模型列表
    """
    query = ModelInfo.all()

    if project_id is not None:
        query = query.filter(project_id=project_id)

    if status is not None:
        query = query.filter(status=status)

    if type is not None:
        query = query.filter(type=type)

    models = await query.offset(offset).limit(limit)
    return models


async def _validate_model_name(model: ModelInfo, new_name: str) -> None:
    """验证模型名称在项目中是否可用"""
    if new_name != model.name:
        existing = await ModelInfo.filter(project_id=model.project_id, name=new_name).first()
        if existing:
            raise ResourceConflictError(f"模型名称 '{new_name}' 在项目中已存在")


def _update_model_fields(
    model: ModelInfo,
    description: Optional[str] = None,
    version: Optional[str] = None,
    type: Optional[ModelType] = None,
    format: Optional[ModelFormat] = None,
    status: Optional[ModelStatus] = None,
    storage_path: Optional[str] = None,
    config: Optional[Dict] = None,
    metrics: Optional[Dict] = None,
) -> None:
    """更新模型的字段值"""
    update_fields = {
        "description": description,
        "version": version,
        "type": type,
        "format": format,
        "status": status,
        "storage_path": storage_path,
        "config": config,
        "metrics": metrics,
    }

    for field, value in update_fields.items():
        if value is not None:
            setattr(model, field, value)


async def update_model(
    model_id: UUID,
    name: Optional[str] = None,
    description: Optional[str] = None,
    version: Optional[str] = None,
    type: Optional[ModelType] = None,
    format: Optional[ModelFormat] = None,
    status: Optional[ModelStatus] = None,
    storage_path: Optional[str] = None,
    config: Optional[Dict] = None,
    metrics: Optional[Dict] = None,
) -> ModelInfo:
    """
    更新模型信息

    Args:
        model_id: 模型ID
        name: 新的模型名称
        description: 新的模型描述
        version: 新的模型版本号
        type: 新的模型类型
        format: 新的模型格式
        status: 新的模型状态
        storage_path: 新的模型文件存储路径
        config: 新的模型配置信息
        metrics: 新的模型评估指标

    Returns:
        ModelInfo: 更新后的模型实例

    Raises:
        ResourceNotFoundError: 当模型不存在时抛出
        ResourceConflictError: 当新模型名称在项目中已存在时抛出
    """
    try:
        model = await ModelInfo.get(id=model_id)
    except DoesNotExist:
        raise ResourceNotFoundError(f"模型 {model_id} 不存在")

    # 验证名称
    if name is not None:
        await _validate_model_name(model, name)
        model.name = name

    # 更新其他字段
    _update_model_fields(
        model,
        description=description,
        version=version,
        type=type,
        format=format,
        status=status,
        storage_path=storage_path,
        config=config,
        metrics=metrics,
    )

    await model.save()
    return model


@atomic()
async def delete_model(model_id: UUID) -> None:
    """
    删除模型

    Args:
        model_id: 模型ID

    Raises:
        ResourceNotFoundError: 当模型不存在时抛出
    """
    try:
        model = await ModelInfo.get(id=model_id)
    except DoesNotExist:
        raise ResourceNotFoundError(f"模型 {model_id} 不存在")

    # 删除模型（由于设置了级联删除，相关的训练任务等资源会自动删除）
    await model.delete()


async def count_models(
    project_id: Optional[UUID] = None,
    status: Optional[ModelStatus] = None,
    type: Optional[ModelType] = None,
) -> int:
    """
    获取模型总数

    Args:
        project_id: 可选的项目ID过滤
        status: 可选的模型状态过滤
        type: 可选的模型类型过滤

    Returns:
        int: 模型总数
    """
    query = ModelInfo.all()

    if project_id is not None:
        query = query.filter(project_id=project_id)

    if status is not None:
        query = query.filter(status=status)

    if type is not None:
        query = query.filter(type=type)

    return await query.count()
