import uuid
from typing import AsyncGenerator
from uuid import UUID

import pytest
from tortoise import Tortoise

from algo_flow.app.algo.crud.model import (
    count_models,
    create_model,
    delete_model,
    get_model,
    list_models,
    update_model,
)
from algo_flow.app.algo.models import Project
from algo_flow.cores.constant.algo import ModelFormat, ModelStatus, ModelType
from algo_flow.cores.exceptions import ResourceConflictError, ResourceNotFoundError

# Shared test data.
# Fixed primary key of the project that the `setup_database` fixture inserts
# before every test; tests pass it as `project_id` when creating models.
test_project_id = UUID("550e8400-e29b-41d4-a716-446655440000")
# Baseline keyword arguments for `create_model`. Tests unpack this dict as-is
# or overlay individual keys (e.g. "name", "status") to build variants, so it
# must not be mutated in place.
test_model_data = {
    "name": "测试模型",
    "description": "这是一个测试模型",
    "version": "1.0.0",
    "type": ModelType.CLASSIFICATION,
    "format": ModelFormat.ONNX,
    "status": ModelStatus.DRAFT,
    "storage_path": "/path/to/model",
    "config": {"param1": "value1"},
    "metrics": {"accuracy": 0.95},
}


@pytest.fixture(autouse=True)
async def setup_database() -> AsyncGenerator[None, None]:
    """Give every test a fresh in-memory SQLite database seeded with one project.

    Setup initialises Tortoise against ``sqlite://:memory:``, generates the
    schema and inserts a project whose id is ``test_project_id``. Teardown
    closes all Tortoise connections.

    NOTE(review): this is an async generator declared with ``pytest.fixture``;
    that presumably relies on the async test plugin running in a mode that
    accepts plain pytest fixtures -- confirm against the project's pytest
    configuration.
    """
    await Tortoise.init(
        db_url="sqlite://:memory:",
        # Register the models under the same dotted path this file imports
        # `Project` from, so the classes Tortoise initialises are the very
        # classes the tests (and the CRUD helpers) use.
        modules={"models": ["algo_flow.app.algo.models"]},
    )
    try:
        await Tortoise.generate_schemas()

        # Seed the project that the tests attach their models to.
        await Project.create(
            id=test_project_id,
            name="测试项目",
            description="这是一个测试项目",
        )

        yield
    finally:
        # Runs after the test, and also when schema generation or seeding
        # fails, so connections opened by `Tortoise.init` are never leaked.
        await Tortoise.close_connections()


@pytest.mark.asyncio
async def test_create_model():
    """Creating a model under the seeded project echoes back every input field."""
    created = await create_model(project_id=test_project_id, **test_model_data)

    # The record must be attached to the seeded project ...
    assert created.project_id == test_project_id
    # ... and expose each supplied field unchanged, checked in payload order.
    for field_name, expected in test_model_data.items():
        assert getattr(created, field_name) == expected


@pytest.mark.asyncio
async def test_create_model_with_nonexistent_project():
    """Creating a model under an unknown project id raises ResourceNotFoundError."""
    # A random UUID that was never inserted by the fixture.
    unknown_project_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await create_model(project_id=unknown_project_id, **test_model_data)


@pytest.mark.asyncio
async def test_create_model_with_duplicate_name():
    """Creating the same model twice raises ResourceConflictError."""
    payload = {"project_id": test_project_id, **test_model_data}

    # The first insert succeeds.
    await create_model(**payload)

    # Repeating the identical payload must be rejected as a conflict.
    with pytest.raises(ResourceConflictError):
        await create_model(**payload)


@pytest.mark.asyncio
async def test_get_model():
    """Fetching a model by id returns the record that was just created."""
    saved = await create_model(project_id=test_project_id, **test_model_data)

    fetched = await get_model(saved.id)

    # Compare the identifying attributes of the fetched record.
    expected_attrs = {
        "id": saved.id,
        "project_id": test_project_id,
        "name": test_model_data["name"],
    }
    for attr, expected in expected_attrs.items():
        assert getattr(fetched, attr) == expected


@pytest.mark.asyncio
async def test_get_nonexistent_model():
    """Fetching an id that was never stored raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await get_model(missing_id)


@pytest.mark.asyncio
async def test_list_models():
    """Listing models honours the project, status and type filters and the limit."""
    # Two models in the same project: one DRAFT (baseline), one READY.
    draft_model = await create_model(project_id=test_project_id, **test_model_data)
    ready_payload = {
        **test_model_data,
        "name": "测试模型2",
        "status": ModelStatus.READY,
    }
    await create_model(project_id=test_project_id, **ready_payload)

    # No filter: both models are returned.
    everything = await list_models()
    assert len(everything) == 2

    # Project filter: both belong to the seeded project.
    in_project = await list_models(project_id=test_project_id)
    assert len(in_project) == 2

    # Status filter: only the baseline model is DRAFT.
    drafts = await list_models(status=ModelStatus.DRAFT)
    assert len(drafts) == 1
    assert drafts[0].id == draft_model.id

    # Type filter: both models share the CLASSIFICATION type.
    classifiers = await list_models(type=ModelType.CLASSIFICATION)
    assert len(classifiers) == 2

    # Pagination: a limit of one yields exactly one row.
    first_page = await list_models(limit=1)
    assert len(first_page) == 1


@pytest.mark.asyncio
async def test_update_model():
    """Updating a model returns the same record with every changed field applied."""
    original = await create_model(project_id=test_project_id, **test_model_data)

    changes = {
        "name": "更新后的模型",
        "description": "这是更新后的描述",
        "version": "2.0.0",
        "status": ModelStatus.READY,
        "config": {"param2": "value2"},
        "metrics": {"accuracy": 0.98},
    }
    result = await update_model(original.id, **changes)

    # Identity is preserved ...
    assert result.id == original.id
    # ... and each requested change is reflected, checked in payload order.
    for field_name, expected in changes.items():
        assert getattr(result, field_name) == expected


@pytest.mark.asyncio
async def test_update_nonexistent_model():
    """Updating an id that was never stored raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await update_model(missing_id, name="新名称")


@pytest.mark.asyncio
async def test_update_model_with_duplicate_name():
    """Renaming a model to a name already used in the project raises a conflict."""
    # Two models in the same project that differ only by name.
    first = await create_model(project_id=test_project_id, **test_model_data)
    second_payload = {**test_model_data, "name": "测试模型2"}
    second = await create_model(project_id=test_project_id, **second_payload)

    # Giving the second model the first model's name must be rejected.
    with pytest.raises(ResourceConflictError):
        await update_model(second.id, name=first.name)


@pytest.mark.asyncio
async def test_delete_model():
    """A deleted model can no longer be fetched."""
    doomed = await create_model(project_id=test_project_id, **test_model_data)
    doomed_id = doomed.id

    await delete_model(doomed_id)

    # Looking the id up again must now fail.
    with pytest.raises(ResourceNotFoundError):
        await get_model(doomed_id)


@pytest.mark.asyncio
async def test_delete_nonexistent_model():
    """Deleting an id that was never stored raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await delete_model(missing_id)


@pytest.mark.asyncio
async def test_count_models():
    """Counting models honours the project, status and type filters."""
    # Two models in the same project: one DRAFT (baseline), one READY.
    await create_model(project_id=test_project_id, **test_model_data)
    ready_payload = {
        **test_model_data,
        "name": "测试模型2",
        "status": ModelStatus.READY,
    }
    await create_model(project_id=test_project_id, **ready_payload)

    # (filters, expected count), evaluated in order.
    scenarios = [
        ({}, 2),  # no filter
        ({"project_id": test_project_id}, 2),  # by project
        ({"status": ModelStatus.DRAFT}, 1),  # by status
        ({"type": ModelType.CLASSIFICATION}, 2),  # by type
    ]
    for filters, expected in scenarios:
        total = await count_models(**filters)
        assert total == expected
