from typing import List, Optional, Set

from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
from passlib.exc import InvalidTokenError
from pydantic import BaseModel, Field, ValidationError
from starlette import status
from tortoise.expressions import Q

from algo_flow.app.system.models import User
from algo_flow.cores.jwt import Token, create_access_token, verify_token
from algo_flow.cores.oauth.github import (
    OAuth2GithubRequestForm,
    get_access_token,
    get_primary_email_by_access_token,
)
from algo_flow.cores.pwd import verify_password
from algo_flow.cores.scope import filter_scopes, scopes

# Router that collects the authentication endpoints defined in this module.
auth_router = APIRouter()

# Bearer-token dependency consumed by ``get_current_user``. ``tokenUrl`` is the
# token endpoint FastAPI advertises in the OpenAPI schema; it presumably points
# at this router's ``/oauth2/password`` route mounted under ``/api/v1/auth`` —
# NOTE(review): keep it in sync if the router's mount prefix changes.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/oauth2/password",
    scopes=scopes,
)


class TokenData(BaseModel):
    """Claims read from a verified access token.

    ``get_current_user`` builds one from the token's ``sub`` and ``scopes``
    claims, so pydantic validation rejects malformed claim values.
    """

    # Subject of the token (the ``sub`` claim); ``None`` when not provided.
    username: Optional[str] = Field(default=None)
    # Scopes carried by the token; each instance gets its own empty list by default.
    scopes: List[str] = Field(default_factory=list)


async def authenticate_user(username: str, password: str) -> Optional[User]:
    """Check a username/password pair against the stored credentials.

    Args:
        username: Login name, matched exactly against ``User.username``.
        password: Plain-text password supplied by the client.

    Returns:
        The matching user, with ``roles__permissions`` prefetched, when the
        password verifies; ``None`` if the user does not exist or the password
        is wrong.
    """
    candidate = await (
        User.get_queryset()
        .prefetch_related("roles__permissions")
        .filter(username=username)
        .first()
    )
    # NOTE(review): an unknown username skips password hashing entirely, so it
    # may answer measurably faster than a wrong password; consider whether
    # timing-based username enumeration matters for this deployment.
    if not candidate or not verify_password(password, candidate.hashed_password):
        return None
    return candidate


async def authenticate_user_by_oauth(username: str) -> Optional[User]:
    """Find the local user for an identity asserted by an external OAuth provider.

    No password is checked here. The GitHub login endpoint passes the
    account's primary e-mail as ``username``.

    Args:
        username: Identifier to look up. It is matched against ``User.email``
            first and, if no user has that e-mail, against ``User.username``.

    Returns:
        The matching user with ``roles__permissions`` prefetched, or ``None``.
    """
    queryset = User.get_queryset().prefetch_related("roles__permissions")
    # Look the two columns up separately instead of with one OR filter. When one
    # user's e-mail equals another user's username, the OR query matches two
    # rows and ``get_or_none`` raises MultipleObjectsReturned (a 500) instead of
    # returning a user. The e-mail match wins because that is the identity the
    # OAuth provider vouches for.
    user = await queryset.get_or_none(email=username)
    if user is None:
        user = await queryset.get_or_none(username=username)
    return user


async def get_user_permissions(user: User) -> Set[str]:
    """Collect the names of every permission granted through the user's roles.

    ``user.roles`` and each role's ``permissions`` are iterated without awaiting.
    They presumably must already be loaded; the user lookups in this module
    prefetch ``roles__permissions`` for that reason.

    Returns:
        The de-duplicated set of permission names (empty when the user has no
        roles or the roles carry no permissions).
    """
    return {perm.name for role in user.roles for perm in role.permissions}


async def get_current_user(
    security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the bearer token to a user and enforce the route's required scopes.

    Args:
        security_scopes: Scopes the route demands.
        token: Raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user named by the token's ``sub`` claim. Superusers are returned
        without any scope check.

    Raises:
        HTTPException: 401 when the token fails verification, lacks a ``sub``
            claim, has claims that fail ``TokenData`` validation, or names an
            unknown user. Also 401 when a required scope is not covered by the
            token's scopes.
    """
    # The WWW-Authenticate challenge advertises the required scopes, if any.
    challenge = (
        f'Bearer scope="{security_scopes.scope_str}"'
        if security_scopes.scopes
        else "Bearer"
    )

    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": challenge},
    )

    # NOTE(review): ``InvalidTokenError`` is imported from ``passlib.exc`` at the
    # top of this file. Confirm that ``verify_token`` raises that class. If it
    # raises the JWT library's own error type instead, bad tokens escape this
    # handler and surface as 500s rather than 401s.
    try:
        claims = verify_token(token)
        subject: Optional[str] = claims.get("sub")
        if subject is None:
            raise credentials_exception
        token_data = TokenData(username=subject, scopes=claims.get("scopes", []))
    except (InvalidTokenError, ValidationError):
        raise credentials_exception

    user = await User.get_queryset().get_or_none(username=subject)
    if user is None:
        raise credentials_exception

    # Superusers bypass scope enforcement entirely.
    if user.is_superuser:
        return user

    # A required scope is satisfied when the token holds it exactly, or holds a
    # parent scope ``held`` such that the requirement starts with ``held:``.
    held_scopes = token_data.scopes
    for required in security_scopes.scopes:
        if not any(
            required == held or required.startswith(f"{held}:") for held in held_scopes
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="权限不足",
                headers={"WWW-Authenticate": challenge},
            )

    return user


async def get_current_active_user(
    current_user: User = Security(get_current_user),
) -> User:
    """Return the authenticated user, rejecting accounts that are not active.

    Token verification and scope enforcement are delegated to
    ``get_current_user``.

    Raises:
        HTTPException: 400 when ``current_user.is_active`` is falsy.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="用户未激活")
    return current_user


@auth_router.post("/password", response_model=Token)
async def login_from_password(
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Issue a token for all of the user's permissions after a password login.

    Scopes requested in the form are not consulted. The token carries every
    permission the user holds through their roles, passed through
    ``filter_scopes``.

    Raises:
        HTTPException: 400 for a wrong username/password or an inactive account.
    """
    user = await authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    # Refuse inactive accounts, as the GitHub login endpoint already does.
    # Otherwise a password login could still mint tokens for a deactivated user.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User is inactive")

    permissions = await get_user_permissions(user)
    filter_permissions = filter_scopes(permissions)

    access_token = create_access_token(data={"sub": user.username, "scopes": filter_permissions})
    return Token(
        access_token=access_token,
        token_type="bearer",
        scopes=filter_permissions,
    )


@auth_router.post("/oauth2/password", response_model=Token)
async def login_from_oauth2_password(
    form_data: OAuth2PasswordRequestForm = Depends(),
):
    """Issue a token limited to the scopes requested in the OAuth2 password form.

    Every requested scope must be one of the user's permissions. The token
    carries the requested scopes (passed through ``filter_scopes``).
    Requesting no scopes yields a token built from an empty permission set.

    Raises:
        HTTPException: 400 for a wrong username/password, an inactive account,
            or when any requested scope is not among the user's permissions.
    """
    user = await authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    # Refuse inactive accounts, as the GitHub login endpoint already does.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User is inactive")

    user_permissions = await get_user_permissions(user)
    # Compare as sets. The form's scope list can repeat a scope, so comparing a
    # set size against the list length would reject a request the user is
    # fully entitled to (and report an empty set as the "missing" permission).
    requested = set(form_data.scopes)
    missing = requested - user_permissions
    if missing:
        raise HTTPException(status_code=400, detail=f"Incorrect permission {missing}")

    filter_permissions = filter_scopes(user_permissions & requested)
    access_token = create_access_token(data={"sub": user.username, "scopes": filter_permissions})
    return Token(
        access_token=access_token,
        token_type="bearer",
        scopes=filter_permissions,
    )


@auth_router.post("/oauth2/github", response_model=Token)
async def login_from_oauth2_github(
    form_data: OAuth2GithubRequestForm,
):
    """Log in with a GitHub OAuth authorization code.

    The code is exchanged for a GitHub access token, which is used to fetch the
    account's primary e-mail. That e-mail is then used to find the local user.
    The issued token carries all of the user's permissions, passed through
    ``filter_scopes``.

    Raises:
        HTTPException: 400 when no local user matches the e-mail or the matched
            user is inactive.
    """
    # Verify ``state`` to prevent CSRF attacks.
    # TODO: implement the state verification logic.

    github_token = await get_access_token(form_data.code)
    primary_email = await get_primary_email_by_access_token(github_token)
    user = await authenticate_user_by_oauth(primary_email.email)

    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="User not found or not active"
        )
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User is inactive")

    granted = filter_scopes(await get_user_permissions(user))
    access_token = create_access_token(data={"sub": user.username, "scopes": granted})
    return Token(access_token=access_token, token_type="bearer", scopes=granted)


@auth_router.post("/logout")
async def logout():
    """Log out the caller; responds with ``None`` and touches no server state.

    NOTE(review): nothing here revokes previously issued access tokens, so a
    token stays usable for as long as ``verify_token`` accepts it. Confirm that
    clients are expected to simply discard their token.
    """
