import uuid
from typing import AsyncGenerator
from uuid import UUID

import pytest
from tortoise import Tortoise

from algo_flow.app.algo.crud.training import (
    count_training_jobs,
    create_training_job,
    delete_training_job,
    get_training_job,
    list_training_jobs,
    update_training_job,
)
from algo_flow.app.algo.models import Dataset, ModelInfo, Project
from algo_flow.cores.constant.algo import (
    DatasetStatus,
    ModelFormat,
    ModelStatus,
    ModelType,
    TaskStatus,
)
from algo_flow.cores.exceptions import ResourceConflictError, ResourceNotFoundError

# Shared test data: fixed primary keys for the records seeded by the
# database fixture, so every test can refer to them by a stable ID.
test_project_id = UUID("550e8400-e29b-41d4-a716-446655440000")
test_dataset_id = UUID("660e8400-e29b-41d4-a716-446655440000")
test_model_id = UUID("770e8400-e29b-41d4-a716-446655440000")

# Default keyword arguments used when creating a training job in the tests.
test_training_data = dict(
    name="测试训练任务",
    description="这是一个测试训练任务",
    hyperparameters=dict(
        learning_rate=0.001,
        batch_size=32,
        epochs=100,
        optimizer="adam",
        loss="cross_entropy",
    ),
)


@pytest.fixture(autouse=True)
async def setup_database() -> AsyncGenerator[None, None]:
    """Give each test a fresh in-memory SQLite database with seed records.

    Initialises Tortoise ORM, generates the schema, and inserts one
    ``Project``, one ``Dataset`` and one ``ModelInfo`` row using the fixed
    module-level IDs. Control is then yielded to the test, and the
    connections are closed afterwards.

    NOTE(review): this is an async generator declared with plain
    ``@pytest.fixture``; that relies on pytest-asyncio running in "auto"
    mode (strict mode requires ``@pytest_asyncio.fixture``) -- confirm the
    project's pytest configuration.
    """
    await Tortoise.init(
        db_url="sqlite://:memory:",
        # Register the models under the same dotted path this file imports
        # them from. A different path (e.g. "app.algo.models") either fails
        # to import or loads a second copy of the module, leaving the classes
        # used in these tests unbound to any connection.
        modules={"models": ["algo_flow.app.algo.models"]},
    )
    try:
        await Tortoise.generate_schemas()

        # Seed project.
        await Project.create(
            id=test_project_id,
            name="测试项目",
            description="这是一个测试项目",
        )

        # Seed dataset.
        await Dataset.create(
            id=test_dataset_id,
            name="测试数据集",
            description="这是一个测试数据集",
            status=DatasetStatus.READY,
            storage_path="/path/to/dataset",
            metadata={
                "format": "COCO",
                "size": 1000,
            },
        )

        # Seed model, attached to the seed project.
        await ModelInfo.create(
            id=test_model_id,
            project_id=test_project_id,
            name="测试模型",
            description="这是一个测试模型",
            version="1.0.0",
            type=ModelType.CLASSIFICATION,
            format=ModelFormat.PYTORCH,
            status=ModelStatus.DRAFT,
            storage_path="/path/to/model",
            config={
                "architecture": "resnet50",
                "input_size": [224, 224, 3],
            },
        )

        yield
    finally:
        # Runs even when schema generation or seeding fails, so a broken
        # setup does not leak the open connection into later tests.
        await Tortoise.close_connections()


@pytest.mark.asyncio
async def test_create_training_job():
    """Creating a job stores the supplied fields and starts it as PENDING."""
    created = await create_training_job(
        model_id=test_model_id,
        dataset_id=test_dataset_id,
        **test_training_data,
    )

    # Fields passed in must be stored unchanged on the returned job.
    for field in ("name", "description", "hyperparameters"):
        assert getattr(created, field) == test_training_data[field]

    # A new job starts as PENDING and references the requested model/dataset.
    assert created.status == TaskStatus.PENDING
    assert created.model_id == test_model_id
    assert created.dataset_id == test_dataset_id


@pytest.mark.asyncio
async def test_create_training_job_with_nonexistent_model():
    """Creating a job with an unknown model ID raises ResourceNotFoundError."""
    unknown_model_id = uuid.uuid4()  # random ID, never seeded by the fixture

    with pytest.raises(ResourceNotFoundError):
        await create_training_job(
            model_id=unknown_model_id,
            dataset_id=test_dataset_id,
            **test_training_data,
        )


@pytest.mark.asyncio
async def test_create_training_job_with_nonexistent_dataset():
    """Creating a job with an unknown dataset ID raises ResourceNotFoundError."""
    unknown_dataset_id = uuid.uuid4()  # random ID, never seeded by the fixture

    with pytest.raises(ResourceNotFoundError):
        await create_training_job(
            model_id=test_model_id,
            dataset_id=unknown_dataset_id,
            **test_training_data,
        )


@pytest.mark.asyncio
async def test_create_training_job_with_duplicate_name():
    """Creating a second job with the same arguments raises ResourceConflictError."""
    job_kwargs = {
        "model_id": test_model_id,
        "dataset_id": test_dataset_id,
        **test_training_data,
    }

    # The first creation is expected to succeed.
    await create_training_job(**job_kwargs)

    # Repeating the identical call must be rejected as a conflict.
    with pytest.raises(ResourceConflictError):
        await create_training_job(**job_kwargs)


@pytest.mark.asyncio
async def test_get_training_job():
    """A job fetched by ID matches the job that was just created."""
    original = await create_training_job(
        model_id=test_model_id,
        dataset_id=test_dataset_id,
        **test_training_data,
    )

    fetched = await get_training_job(original.id)

    # The fetched record must be the one created above.
    assert fetched.id == original.id
    assert fetched.name == test_training_data["name"]
    assert fetched.model_id == test_model_id
    assert fetched.dataset_id == test_dataset_id


@pytest.mark.asyncio
async def test_get_nonexistent_training_job():
    """Fetching a job by an unknown ID raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await get_training_job(missing_id)


@pytest.mark.asyncio
async def test_list_training_jobs():
    """Listing returns the expected number of jobs for each filter and for a limit."""
    # Create two jobs that differ only in name.
    for job_name in (test_training_data["name"], "测试训练任务2"):
        await create_training_job(
            model_id=test_model_id,
            dataset_id=test_dataset_id,
            **{**test_training_data, "name": job_name},
        )

    # (keyword arguments for list_training_jobs, expected result length)
    cases = [
        ({}, 2),  # no filter
        ({"model_id": test_model_id}, 2),  # filter by model
        ({"dataset_id": test_dataset_id}, 2),  # filter by dataset
        ({"status": TaskStatus.PENDING}, 2),  # filter by status
        ({"limit": 1}, 1),  # pagination
    ]
    for list_kwargs, expected_len in cases:
        jobs = await list_training_jobs(**list_kwargs)
        assert len(jobs) == expected_len


@pytest.mark.asyncio
async def test_update_training_job():
    """Updating a job returns the same job with every changed field applied."""
    original = await create_training_job(
        model_id=test_model_id,
        dataset_id=test_dataset_id,
        **test_training_data,
    )

    # New values for every field this test updates.
    new_hyperparameters = {
        "learning_rate": 0.0005,
        "batch_size": 64,
        "epochs": 200,
        "optimizer": "sgd",
        "loss": "focal",
    }
    new_metrics = {
        "train_loss": 0.1,
        "train_accuracy": 0.95,
        "val_loss": 0.2,
        "val_accuracy": 0.9,
    }
    changes = {
        "name": "更新后的训练任务",
        "description": "这是更新后的描述",
        "status": TaskStatus.RUNNING,
        "hyperparameters": new_hyperparameters,
        "metrics": new_metrics,
    }

    result = await update_training_job(original.id, **changes)

    # Same record, with each changed field reflected on the returned job.
    assert result.id == original.id
    for field, expected in changes.items():
        assert getattr(result, field) == expected


@pytest.mark.asyncio
async def test_update_nonexistent_training_job():
    """Updating a job by an unknown ID raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await update_training_job(missing_id, name="新名称")


@pytest.mark.asyncio
async def test_update_training_job_with_duplicate_name():
    """Renaming a job to another job's name raises ResourceConflictError."""
    common = {"model_id": test_model_id, "dataset_id": test_dataset_id}

    # Two jobs on the same model and dataset, with different names.
    first = await create_training_job(**common, **test_training_data)
    second_fields = {**test_training_data, "name": "测试训练任务2"}
    second = await create_training_job(**common, **second_fields)

    # Giving the second job the first job's name must be rejected.
    with pytest.raises(ResourceConflictError):
        await update_training_job(second.id, name=first.name)


@pytest.mark.asyncio
async def test_delete_training_job():
    """After deletion, fetching the job raises ResourceNotFoundError."""
    doomed = await create_training_job(
        model_id=test_model_id,
        dataset_id=test_dataset_id,
        **test_training_data,
    )
    doomed_id = doomed.id

    await delete_training_job(doomed_id)

    # The job must no longer be retrievable.
    with pytest.raises(ResourceNotFoundError):
        await get_training_job(doomed_id)


@pytest.mark.asyncio
async def test_delete_nonexistent_training_job():
    """Deleting a job by an unknown ID raises ResourceNotFoundError."""
    missing_id = uuid.uuid4()

    with pytest.raises(ResourceNotFoundError):
        await delete_training_job(missing_id)


@pytest.mark.asyncio
async def test_count_training_jobs():
    """Counting returns 2 with no filter and with each individual filter."""
    # Create two jobs that differ only in name.
    for job_name in (test_training_data["name"], "测试训练任务2"):
        await create_training_job(
            model_id=test_model_id,
            dataset_id=test_dataset_id,
            **{**test_training_data, "name": job_name},
        )

    # Keyword arguments for count_training_jobs; each must count both jobs.
    filter_sets = [
        {},  # no filter
        {"model_id": test_model_id},  # filter by model
        {"dataset_id": test_dataset_id},  # filter by dataset
        {"status": TaskStatus.PENDING},  # filter by status
    ]
    for count_kwargs in filter_sets:
        total = await count_training_jobs(**count_kwargs)
        assert total == 2
