import os
import tempfile
import threading
import time

import nacos
import torch
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import FileResponse, JSONResponse

# --- Nacos service-discovery settings; each value can be overridden by an env var ---
SERVER_ADDRESSES = os.environ.get("NACOS_SERVER", "8.155.3.20:8848")
NAMESPACE = os.environ.get("NACOS_NAMESPACE", "")
# NOTE(review): these fall back to the literal "nacos"/"nacos" credentials;
# supply real ones through the environment in any non-dev deployment.
USERNAME = os.environ.get("NACOS_USERNAME", "nacos")
PASSWORD = os.environ.get("NACOS_PASSWORD", "nacos")

# Identity under which this process registers itself in the naming service.
# NOTE(review): the default IP is loopback, which other hosts cannot reach —
# set NACOS_SERVICE_IP to a routable address when registering remotely.
SERVICE_NAME = os.environ.get("NACOS_SERVICE_NAME", "pt2onnx-service")
IP = os.environ.get("NACOS_SERVICE_IP", "127.0.0.1")
PORT = int(os.environ.get("NACOS_SERVICE_PORT", "8000"))

# Shared Nacos client used for registration, heartbeats and deregistration.
client = nacos.NacosClient(
    SERVER_ADDRESSES,
    namespace=NAMESPACE,
    username=USERNAME,
    password=PASSWORD,
)


def register_instance():
    """Register this process with Nacos as SERVICE_NAME at IP:PORT."""
    service, host, port = SERVICE_NAME, IP, PORT
    client.add_naming_instance(service, host, port)


def deregister_instance():
    """Remove this process (SERVICE_NAME at IP:PORT) from the Nacos registry."""
    service, host, port = SERVICE_NAME, IP, PORT
    client.remove_naming_instance(service, host, port)


def heartbeat():
    """Send a Nacos heartbeat for this instance every 10 seconds, forever.

    Never returns; it is meant to run on a background daemon thread. A failed
    beat is printed and the loop carries on with the next one.
    """

    def beat_once():
        try:
            client.send_heartbeat(SERVICE_NAME, IP, PORT)
        except Exception as exc:
            # Report and keep going: one missed beat must not end the loop.
            print(f"[Nacos] 心跳失败: {exc}")

    while True:
        beat_once()
        time.sleep(10)


def convert_pt_to_onnx(
    pt_path,
    onnx_path,
    input_shape=(1, 3, 224, 224),
    opset_version=11,
    do_constant_folding=True,
    external_data=True,
    optimize=False,
    input_names=None,
    output_names=None,
    dynamic_axes=None,
    verbose=False,
):
    """Export a pickled PyTorch model to ONNX, tracing it with a random input.

    Args:
        pt_path: File written by ``torch.save(model, path)``. It must hold a
            whole ``torch.nn.Module``, not only a ``state_dict``.
        onnx_path: Destination path for the ONNX file.
        input_shape: Shape of the random dummy tensor fed to the model.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
        do_constant_folding, external_data, optimize, verbose: Forwarded
            unchanged to ``torch.onnx.export``.
        input_names: Graph input names. An empty or ``None`` value falls back
            to ``["input"]``.
        output_names: Graph output names. An empty or ``None`` value falls back
            to ``["output"]``.
        dynamic_axes: Passed to ``torch.onnx.export``. When ``None``, axis 0 of
            every input and output name is marked as the dynamic
            ``"batch_size"`` axis. Pass ``{}`` to export fully static shapes.

    Raises:
        TypeError: If ``pt_path`` does not deserialize to a ``torch.nn.Module``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects, so loading an
    # untrusted file can execute code. It is required to restore a full module;
    # only feed this function files from trusted sources.
    model = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=False)
    if not isinstance(model, torch.nn.Module):
        # Typically a state_dict (a plain mapping of tensors) was saved instead.
        raise TypeError(
            f"{pt_path} contains a {type(model).__name__}, not a torch.nn.Module; "
            "save the whole model with torch.save(model, path), not its state_dict"
        )
    model.eval()

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
    if dynamic_axes is None:
        # Key the batch axis by the names actually exported. Hard-coded
        # "input"/"output" keys would not match custom names and would leave
        # the batch dimension static.
        dynamic_axes = {
            name: {0: "batch_size"} for name in [*input_names, *output_names]
        }

    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        external_data=external_data,
        verbose=verbose,
        optimize=optimize,
    )
    print(f"Exported {pt_path} to {onnx_path}")


# ASGI application serving the /convert and /download endpoints.
app = FastAPI()


@app.post("/convert")
async def convert(
    pt_file: UploadFile = File(...),
    onnx_name: str = Form(...),
    c: int = Form(3),
    h: int = Form(224),
    w: int = Form(224),
    opset_version: int = Form(11),
    do_constant_folding: bool = Form(True),
    external_data: bool = Form(True),
    optimize: bool = Form(False),
    input_names: str = Form("input"),
    output_names: str = Form("output"),
    verbose: bool = Form(False),
):
    """Convert an uploaded PyTorch ``.pt`` model to ONNX.

    The upload is spooled to a temporary ``.pt`` file and exported with a dummy
    input of shape ``(1, c, h, w)``. The ONNX file is written to the system
    temp directory, and its file name starts with the base name of
    ``onnx_name``.

    ``input_names`` / ``output_names`` are comma-separated lists. Blank entries
    are dropped, and an empty list falls back to "input" / "output". Axis 0 of
    every named input and output is exported as the dynamic ``batch_size``
    axis.

    Returns:
        On success, ``{"onnx_path": <server path>, "msg": "success"}``; the file
        can then be fetched through ``/download``. On failure,
        ``{"msg": <error>}`` with HTTP 500.
    """
    import contextlib

    # Keep only the final path component: the name becomes a temp-file prefix,
    # and "../" segments would otherwise place the file outside the temp dir.
    safe_prefix = os.path.basename(onnx_name)
    in_names = [n.strip() for n in input_names.split(",") if n.strip()] or ["input"]
    out_names = [n.strip() for n in output_names.split(",") if n.strip()] or ["output"]
    # Mark the batch axis on the names actually exported; hard-coded
    # "input"/"output" keys would not match user-supplied names.
    dynamic_axes = {name: {0: "batch_size"} for name in in_names + out_names}

    pt_path = None
    onnx_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as pt_tmp:
            pt_path = pt_tmp.name
            # Copy in 1 MiB chunks rather than reading the whole upload at once.
            while True:
                chunk = await pt_file.read(1024 * 1024)
                if not chunk:
                    break
                pt_tmp.write(chunk)
        with tempfile.NamedTemporaryFile(
            delete=False, prefix=safe_prefix, suffix=".onnx"
        ) as onnx_tmp:
            onnx_path = onnx_tmp.name
        # NOTE(review): the export runs synchronously inside this async handler,
        # so the event loop is blocked until it finishes. Before moving it to a
        # worker thread, check that concurrent torch.onnx.export calls are safe.
        # NOTE(review): the upload is unpickled with torch.load(weights_only=False),
        # which can execute arbitrary code; expose this only to trusted clients.
        convert_pt_to_onnx(
            pt_path,
            onnx_path,
            input_shape=(1, c, h, w),
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            external_data=external_data,
            optimize=optimize,
            input_names=in_names,
            output_names=out_names,
            dynamic_axes=dynamic_axes,
            verbose=verbose,
        )
    except Exception as e:
        # Do not leave an empty or partially written .onnx file behind.
        if onnx_path is not None:
            with contextlib.suppress(OSError):
                os.remove(onnx_path)
        return JSONResponse({"msg": str(e)}, status_code=500)
    finally:
        # The uploaded checkpoint is only needed for the conversion itself.
        if pt_path is not None:
            with contextlib.suppress(OSError):
                os.remove(pt_path)
    return JSONResponse({"onnx_path": onnx_path, "msg": "success"})


@app.get("/download")
def download(onnx_path: str):
    """Send back an ONNX file previously produced by ``/convert``.

    Only regular ``.onnx`` files located directly in the system temp directory
    are served; ``/convert`` writes its output there. Symlinks and ``..`` are
    resolved before checking. Every other path (e.g. ``/etc/passwd``) gets the
    same 404 as a missing file, so the endpoint cannot read arbitrary files on
    the server.

    Args:
        onnx_path: Server-side path, as returned in the ``onnx_path`` field of a
            successful ``/convert`` response.
    """
    real_path = os.path.realpath(onnx_path)
    temp_dir = os.path.realpath(tempfile.gettempdir())
    if (
        os.path.dirname(real_path) != temp_dir
        or not real_path.endswith(".onnx")
        or not os.path.isfile(real_path)
    ):
        return JSONResponse({"msg": "文件不存在"}, status_code=404)
    return FileResponse(
        real_path, filename=os.path.basename(real_path), media_type="application/octet-stream"
    )


@app.on_event("startup")
def startup_event():
    """Register with Nacos, then keep the registration alive in the background."""
    register_instance()
    # Daemon thread: it will not keep the process alive after the server exits.
    beat_thread = threading.Thread(target=heartbeat, daemon=True)
    beat_thread.start()


@app.on_event("shutdown")
def shutdown_event():
    """Take this instance out of the Nacos registry as the app stops."""
    return deregister_instance()


if __name__ == "__main__":
    # Listen on every interface, on the same PORT that is registered with Nacos.
    uvicorn.run(app, port=PORT, host="0.0.0.0")
