import uuid
from typing import AsyncGenerator

import pytest
from tortoise import Tortoise

from algo_flow.app.algo.crud.compute import (
    count_compute_nodes,
    create_compute_node,
    delete_compute_node,
    get_compute_node,
    list_compute_nodes,
    update_compute_node,
)
from algo_flow.cores.constant.algo import ComputeNodeStatus, ComputeNodeType
from algo_flow.cores.exceptions import ResourceConflictError, ResourceNotFoundError

# 测试数据
test_compute_data = {
    "name": "测试计算节点",
    "description": "这是一个测试计算节点",
    "type": ComputeNodeType.GPU,
    "ip_address": "192.168.1.100",
    "port": 8000,
    "resources": {
        "gpu": {
            "count": 4,
            "memory": 16384,
            "model": "NVIDIA A100",
        },
        "cpu": {
            "cores": 32,
            "memory": 131072,
        },
        "disk": {
            "total": 1024000,
            "available": 512000,
        },
    },
}


@pytest.fixture(autouse=True)
async def setup_database() -> AsyncGenerator:
    """设置测试数据库"""
    await Tortoise.init(
        db_url="sqlite://:memory:",
        modules={"models": ["app.algo.models"]},
    )
    await Tortoise.generate_schemas()

    yield

    await Tortoise.close_connections()


@pytest.mark.asyncio
async def test_create_compute_node():
    """测试创建计算节点"""
    # 创建计算节点
    node = await create_compute_node(**test_compute_data)

    # 验证计算节点信息
    assert node.name == test_compute_data["name"]
    assert node.description == test_compute_data["description"]
    assert node.type == test_compute_data["type"]
    assert node.ip_address == test_compute_data["ip_address"]
    assert node.port == test_compute_data["port"]
    assert node.resources == test_compute_data["resources"]
    assert node.status == ComputeNodeStatus.OFFLINE


@pytest.mark.asyncio
async def test_create_compute_node_with_duplicate_name():
    """测试创建同名计算节点"""
    # 先创建一个计算节点
    await create_compute_node(**test_compute_data)

    # 尝试创建同名计算节点
    with pytest.raises(ResourceConflictError):
        await create_compute_node(**test_compute_data)


@pytest.mark.asyncio
async def test_create_compute_node_with_duplicate_ip_port():
    """测试创建具有相同IP和端口的计算节点"""
    # 先创建一个计算节点
    await create_compute_node(**test_compute_data)

    # 尝试创建具有相同IP和端口的计算节点
    with pytest.raises(ResourceConflictError):
        await create_compute_node(
            **{
                **test_compute_data,
                "name": "另一个计算节点",
            }
        )


@pytest.mark.asyncio
async def test_get_compute_node():
    """测试获取计算节点"""
    # 先创建一个计算节点
    created_node = await create_compute_node(**test_compute_data)

    # 获取计算节点
    node = await get_compute_node(created_node.id)

    # 验证计算节点信息
    assert node.id == created_node.id
    assert node.name == test_compute_data["name"]
    assert node.type == test_compute_data["type"]
    assert node.ip_address == test_compute_data["ip_address"]
    assert node.port == test_compute_data["port"]


@pytest.mark.asyncio
async def test_get_nonexistent_compute_node():
    """测试获取不存在的计算节点"""
    with pytest.raises(ResourceNotFoundError):
        await get_compute_node(uuid.uuid4())


@pytest.mark.asyncio
async def test_list_compute_nodes():
    """测试获取计算节点列表"""
    # 创建多个计算节点
    await create_compute_node(**test_compute_data)
    await create_compute_node(
        **{
            **test_compute_data,
            "name": "测试计算节点2",
            "ip_address": "192.168.1.101",
        }
    )

    # 测试无过滤条件
    nodes = await list_compute_nodes()
    assert len(nodes) == 2

    # 测试按类型过滤
    nodes = await list_compute_nodes(type=ComputeNodeType.GPU)
    assert len(nodes) == 2

    # 测试按状态过滤
    nodes = await list_compute_nodes(status=ComputeNodeStatus.OFFLINE)
    assert len(nodes) == 2

    # 测试分页
    nodes = await list_compute_nodes(limit=1)
    assert len(nodes) == 1


@pytest.mark.asyncio
async def test_update_compute_node():
    """测试更新计算节点"""
    # 先创建一个计算节点
    node = await create_compute_node(**test_compute_data)

    # 更新计算节点
    updated_data = {
        "name": "更新后的计算节点",
        "description": "这是更新后的描述",
        "status": ComputeNodeStatus.ONLINE,
        "resources": {
            "gpu": {
                "count": 8,
                "memory": 32768,
                "model": "NVIDIA A100",
            },
            "cpu": {
                "cores": 64,
                "memory": 262144,
            },
            "disk": {
                "total": 2048000,
                "available": 1024000,
            },
        },
    }
    updated_node = await update_compute_node(node.id, **updated_data)

    # 验证更新后的信息
    assert updated_node.id == node.id
    assert updated_node.name == updated_data["name"]
    assert updated_node.description == updated_data["description"]
    assert updated_node.status == updated_data["status"]
    assert updated_node.resources == updated_data["resources"]


@pytest.mark.asyncio
async def test_update_nonexistent_compute_node():
    """测试更新不存在的计算节点"""
    with pytest.raises(ResourceNotFoundError):
        await update_compute_node(
            uuid.uuid4(),
            name="新名称",
        )


@pytest.mark.asyncio
async def test_update_compute_node_with_duplicate_name():
    """测试更新计算节点时使用重复的名称"""
    # 创建两个计算节点
    node1 = await create_compute_node(**test_compute_data)
    node2 = await create_compute_node(
        **{
            **test_compute_data,
            "name": "测试计算节点2",
            "ip_address": "192.168.1.101",
        }
    )

    # 尝试将 node2 的名称更新为 node1 的名称
    with pytest.raises(ResourceConflictError):
        await update_compute_node(
            node2.id,
            name=node1.name,
        )


@pytest.mark.asyncio
async def test_update_compute_node_with_duplicate_ip_port():
    """测试更新计算节点时使用重复的IP和端口"""
    # 创建两个计算节点
    node1 = await create_compute_node(**test_compute_data)
    node2 = await create_compute_node(
        **{
            **test_compute_data,
            "name": "测试计算节点2",
            "ip_address": "192.168.1.101",
        }
    )

    # 尝试将 node2 的IP和端口更新为 node1 的IP和端口
    with pytest.raises(ResourceConflictError):
        await update_compute_node(
            node2.id,
            ip_address=node1.ip_address,
            port=node1.port,
        )


@pytest.mark.asyncio
async def test_delete_compute_node():
    """测试删除计算节点"""
    # 先创建一个计算节点
    node = await create_compute_node(**test_compute_data)

    # 删除计算节点
    await delete_compute_node(node.id)

    # 验证计算节点已被删除
    with pytest.raises(ResourceNotFoundError):
        await get_compute_node(node.id)


@pytest.mark.asyncio
async def test_delete_nonexistent_compute_node():
    """测试删除不存在的计算节点"""
    with pytest.raises(ResourceNotFoundError):
        await delete_compute_node(uuid.uuid4())


@pytest.mark.asyncio
async def test_count_compute_nodes():
    """测试获取计算节点总数"""
    # 创建多个计算节点
    await create_compute_node(**test_compute_data)
    await create_compute_node(
        **{
            **test_compute_data,
            "name": "测试计算节点2",
            "ip_address": "192.168.1.101",
        }
    )

    # 测试无过滤条件
    count = await count_compute_nodes()
    assert count == 2

    # 测试按类型过滤
    count = await count_compute_nodes(type=ComputeNodeType.GPU)
    assert count == 2

    # 测试按状态过滤
    count = await count_compute_nodes(status=ComputeNodeStatus.OFFLINE)
    assert count == 2


@pytest.mark.asyncio
async def test_update_compute_node_partial_fields():
    """测试更新计算节点的部分字段"""
    # 先创建一个计算节点
    node = await create_compute_node(**test_compute_data)

    # 只更新状态字段
    updated_node = await update_compute_node(
        node.id,
        status=ComputeNodeStatus.ONLINE,
    )
    assert updated_node.status == ComputeNodeStatus.ONLINE
    assert updated_node.name == test_compute_data["name"]  # 其他字段保持不变

    # 只更新类型字段
    updated_node = await update_compute_node(
        node.id,
        type=ComputeNodeType.CPU,
    )
    assert updated_node.type == ComputeNodeType.CPU
    assert updated_node.status == ComputeNodeStatus.ONLINE  # 之前更新的状态保持不变

    # 只更新描述字段
    new_description = "新的描述"
    updated_node = await update_compute_node(
        node.id,
        description=new_description,
    )
    assert updated_node.description == new_description
    assert updated_node.type == ComputeNodeType.CPU  # 之前更新的类型保持不变


@pytest.mark.asyncio
async def test_update_compute_node_ip_or_port():
    """测试只更新计算节点的 IP 地址或端口"""
    # 先创建一个计算节点
    node = await create_compute_node(**test_compute_data)

    # 只更新 IP 地址
    new_ip = "192.168.1.200"
    updated_node = await update_compute_node(
        node.id,
        ip_address=new_ip,
    )
    assert updated_node.ip_address == new_ip
    assert updated_node.port == test_compute_data["port"]  # 端口保持不变

    # 只更新端口
    new_port = 9000
    updated_node = await update_compute_node(
        node.id,
        port=new_port,
    )
    assert updated_node.ip_address == new_ip  # IP 保持更新后的值
    assert updated_node.port == new_port
