from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from prefect.events.schemas.automations import Automation  # noqa: TC002
from pydantic import BaseModel, Field, computed_field
from typing_extensions import Self

from infrahub.core import registry
from infrahub.core.schema.schema_branch_computed import (  # noqa: TC001
    ComputedAttributeTarget,
    ComputedAttributeTriggerNode,
    PythonDefinition,
)
from infrahub.events import NodeCreatedEvent, NodeUpdatedEvent
from infrahub.trigger.constants import NAME_SEPARATOR
from infrahub.trigger.models import (
    EventTrigger,
    ExecuteWorkflow,
    TriggerBranchDefinition,
    TriggerType,
)
from infrahub.workflows.catalogue import (
    COMPUTED_ATTRIBUTE_PROCESS_JINJA2,
    COMPUTED_ATTRIBUTE_PROCESS_TRANSFORM,
    QUERY_COMPUTED_ATTRIBUTE_TRANSFORM_TARGETS,
)

if TYPE_CHECKING:
    from uuid import UUID

    from infrahub.git.models import RepositoryData


class ComputedAttributeAutomations(BaseModel):
    """Index of Prefect automations, keyed first by identifier and then by scope."""

    # identifier -> scope -> automation; the defaultdict lets from_prefect insert without pre-creating keys.
    data: dict[str, dict[str, Automation]] = Field(default_factory=lambda: defaultdict(dict))  # type: ignore[arg-type]

    @classmethod
    def from_prefect(cls, automations: list[Automation], prefix: str = "") -> Self:
        """Build an index from the automations whose name starts with ``prefix``.

        A name must split into exactly three parts on ``NAME_SEPARATOR``; the second part
        is used as the scope and the third as the identifier. Other names are skipped.
        """
        index = cls()
        matching = (candidate for candidate in automations if candidate.name.startswith(prefix))
        for automation in matching:
            parts = automation.name.split(NAME_SEPARATOR)
            if len(parts) != 3:
                continue
            _, scope, identifier = parts
            index.data[identifier][scope] = automation
        return index

    def get(self, identifier: str, scope: str) -> Automation:
        """Return the automation stored for ``identifier`` and ``scope``.

        Raises:
            KeyError: if no automation is registered for that pair.
        """
        # dict.get does not trigger defaultdict.__missing__, so lookups never add keys.
        scopes = self.data.get(identifier, {})
        if scope not in scopes:
            raise KeyError(f"Unable to find an automation for {identifier} {scope}")
        return scopes[scope]

    def has(self, identifier: str, scope: str) -> bool:
        """Return whether an automation is stored for ``identifier`` and ``scope``."""
        return scope in self.data.get(identifier, {})

    @property
    def all_automation_ids(self) -> list[UUID]:
        """IDs of every indexed automation, across all identifiers and scopes."""
        return [automation.id for scopes in self.data.values() for automation in scopes.values()]


class PythonTransformComputedAttribute(BaseModel):
    """Description of a computed attribute whose value is produced by a Python transform."""

    name: str
    repository_id: str
    repository_name: str
    repository_kind: str
    query_name: str
    query_models: list[str]
    computed_attribute: PythonDefinition
    default_schema: bool
    branch_name: str
    # branch name -> repository commit on that branch; filled by populate_branch_commit.
    branch_commit: dict[str, str] = field(default_factory=dict)

    @computed_field
    def repository_commit(self) -> str:
        """Commit recorded for ``branch_name``; raises KeyError if that branch was never populated."""
        commit = self.branch_commit[self.branch_name]
        return commit

    def populate_branch_commit(self, repository_data: RepositoryData | None = None) -> None:
        """Merge the per-branch commits of ``repository_data`` into ``branch_commit``.

        Does nothing when ``repository_data`` is falsy. Existing entries for the same branch
        are overwritten.
        """
        if not repository_data:
            return
        self.branch_commit.update(repository_data.branches.items())

    def get_altered_branches(self) -> list[str]:
        """Return the branches whose commit differs from the default branch's commit.

        If the default branch has no recorded commit, every known branch is returned.
        """
        default_branch = registry.default_branch
        if default_branch not in self.branch_commit:
            return list(self.branch_commit.keys())
        reference_commit = self.branch_commit[default_branch]
        altered: list[str] = []
        for branch_name, commit in self.branch_commit.items():
            if commit != reference_commit:
                altered.append(branch_name)
        return altered


@dataclass()
class PythonTransformTarget:
    """A single node targeted by a Python transform computed attribute."""

    kind: str  # schema kind of the target node
    object_id: str  # id of the target node


class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition):
    """Trigger definition that runs the Jinja2 computed-attribute workflow when a watched node changes."""

    type: TriggerType = TriggerType.COMPUTED_ATTR_JINJA2
    computed_attribute: ComputedAttributeTarget

    @classmethod
    def from_computed_attribute(
        cls,
        branch: str,
        computed_attribute: ComputedAttributeTarget,
        trigger_node: ComputedAttributeTriggerNode,
        branches_out_of_scope: list[str] | None = None,
    ) -> Self:
        """Create a trigger definition for a computed attribute of type Jinja2.

        The trigger listens for node created/updated events on ``trigger_node.kind`` whose
        related resources touch one of ``trigger_node.fields``. It then runs the
        COMPUTED_ATTRIBUTE_PROCESS_JINJA2 workflow for the node that emitted the event.

        Args:
            branch: Branch this definition belongs to.
            computed_attribute: The Jinja2 computed attribute to recompute.
            trigger_node: Node kind and field names whose changes fire the trigger.
            branches_out_of_scope: Branches to exclude from the match. When empty or None,
                a non-default ``branch`` restricts the match to that branch, and the default
                branch applies no branch filter.

        Returns:
            The trigger definition, named ``<key_name><sep>kind<sep><trigger kind>``.
        """
        event_trigger = EventTrigger()
        event_trigger.events.update({NodeCreatedEvent.event_name, NodeUpdatedEvent.event_name})
        event_trigger.match = {"infrahub.node.kind": trigger_node.kind}
        if branches_out_of_scope:
            # NOTE(review): the "!" prefix is assumed to be Prefect's negated match,
            # i.e. "any branch except these" — confirm against the Prefect trigger docs.
            event_trigger.match["infrahub.branch.name"] = [
                f"!{excluded_branch}" for excluded_branch in branches_out_of_scope
            ]
        elif branch != registry.default_branch:
            event_trigger.match["infrahub.branch.name"] = branch

        # Only fire when the event's related resources report a change to one of the watched fields.
        event_trigger.match_related = {
            "prefect.resource.role": ["infrahub.node.attribute_update", "infrahub.node.relationship_update"],
            "infrahub.field.name": trigger_node.fields,
        }

        # Templated parameters are rendered from the triggering event. `updated_fields` and `context`
        # are serialized with `tojson` and wrapped as "json", presumably so the workflow receives
        # structured values rather than strings.
        workflow = ExecuteWorkflow(
            workflow=COMPUTED_ATTRIBUTE_PROCESS_JINJA2,
            parameters={
                "branch_name": "{{ event.resource['infrahub.branch.name'] }}",
                "node_kind": "{{ event.resource['infrahub.node.kind'] }}",
                "object_id": "{{ event.resource['infrahub.node.id'] }}",
                "computed_attribute_name": computed_attribute.attribute.name,
                "computed_attribute_kind": computed_attribute.kind,
                "updated_fields": {
                    "__prefect_kind": "json",
                    "value": {
                        "__prefect_kind": "jinja",
                        "template": "{{ event.payload['data']['fields'] | tojson }}",
                    },
                },
                "context": {
                    "__prefect_kind": "json",
                    "value": {
                        "__prefect_kind": "jinja",
                        "template": "{{ event.payload['context'] | tojson }}",
                    },
                },
            },
        )

        return cls(
            name=f"{computed_attribute.key_name}{NAME_SEPARATOR}kind{NAME_SEPARATOR}{trigger_node.kind}",
            branch=branch,
            computed_attribute=computed_attribute,
            trigger=event_trigger,
            actions=[workflow],
        )


class ComputedAttrPythonTriggerDefinition(TriggerBranchDefinition):
    """Trigger definition that runs a Python-transform computed attribute when a node of its kind changes."""

    type: TriggerType = TriggerType.COMPUTED_ATTR_PYTHON
    computed_attribute: PythonTransformComputedAttribute

    @classmethod
    def from_object(
        cls,
        branch: str,
        computed_attribute: PythonTransformComputedAttribute,
        branches_out_of_scope: list[str] | None = None,
    ) -> Self:
        """Create a trigger definition for a computed attribute backed by a Python transform.

        The trigger listens for node created/updated events on the kind that owns the computed
        attribute. No field filter (``match_related``) is set, so any such event matches. It then
        runs the COMPUTED_ATTRIBUTE_PROCESS_TRANSFORM workflow for the node that emitted the event.

        Args:
            branch: Branch this definition belongs to.
            computed_attribute: The Python-transform computed attribute to recompute.
            branches_out_of_scope: Branches to exclude from the match. When empty or None,
                a non-default ``branch`` restricts the match to that branch, and the default
                branch applies no branch filter.

        Returns:
            The trigger definition, named after the computed attribute's ``key_name``.
        """
        target = computed_attribute.computed_attribute

        event_trigger = EventTrigger()
        event_trigger.events.update({NodeCreatedEvent.event_name, NodeUpdatedEvent.event_name})
        event_trigger.match = {
            "infrahub.node.kind": [target.kind],
        }

        if branches_out_of_scope:
            # NOTE(review): the "!" prefix is assumed to be Prefect's negated match,
            # i.e. "any branch except these" — confirm against the Prefect trigger docs.
            event_trigger.match["infrahub.branch.name"] = [
                f"!{excluded_branch}" for excluded_branch in branches_out_of_scope
            ]
        elif branch != registry.default_branch:
            event_trigger.match["infrahub.branch.name"] = branch

        workflow = ExecuteWorkflow(
            workflow=COMPUTED_ATTRIBUTE_PROCESS_TRANSFORM,
            parameters={
                "branch_name": "{{ event.resource['infrahub.branch.name'] }}",
                "node_kind": "{{ event.resource['infrahub.node.kind'] }}",
                "object_id": "{{ event.resource['infrahub.node.id'] }}",
                "computed_attribute_name": target.attribute.name,
                "computed_attribute_kind": target.kind,
                "context": {
                    "__prefect_kind": "json",
                    "value": {
                        "__prefect_kind": "jinja",
                        "template": "{{ event.payload['context'] | tojson }}",
                    },
                },
            },
        )

        return cls(
            name=target.key_name,
            branch=branch,
            computed_attribute=computed_attribute,
            trigger=event_trigger,
            actions=[workflow],
        )


class ComputedAttrPythonQueryTriggerDefinition(TriggerBranchDefinition):
    """Trigger definition that looks up Python-transform targets when a node used by its query changes."""

    type: TriggerType = TriggerType.COMPUTED_ATTR_PYTHON_QUERY

    @classmethod
    def from_object(
        cls,
        branch: str,
        computed_attribute: PythonTransformComputedAttribute,
        branches_out_of_scope: list[str] | None = None,
    ) -> Self:
        """Create a trigger that fires when a node of any kind in ``query_models`` is created or updated.

        It runs the QUERY_COMPUTED_ATTRIBUTE_TRANSFORM_TARGETS workflow with the emitting node's
        branch, kind and id. Unlike the other trigger definitions, it does not pass any
        computed-attribute parameters.

        Args:
            branch: Branch this definition belongs to.
            computed_attribute: Supplies ``query_models`` (the kinds to watch) and the trigger name.
            branches_out_of_scope: Branches to exclude from the match. When empty or None,
                a non-default ``branch`` restricts the match to that branch, and the default
                branch applies no branch filter.

        Returns:
            The trigger definition, named after the computed attribute's ``key_name``.
        """
        event_trigger = EventTrigger()
        event_trigger.events.update({NodeCreatedEvent.event_name, NodeUpdatedEvent.event_name})
        event_trigger.match = {
            "infrahub.node.kind": computed_attribute.query_models,
        }

        if branches_out_of_scope:
            # NOTE(review): the "!" prefix is assumed to be Prefect's negated match,
            # i.e. "any branch except these" — confirm against the Prefect trigger docs.
            event_trigger.match["infrahub.branch.name"] = [
                f"!{excluded_branch}" for excluded_branch in branches_out_of_scope
            ]
        elif branch != registry.default_branch:
            event_trigger.match["infrahub.branch.name"] = branch

        workflow = ExecuteWorkflow(
            workflow=QUERY_COMPUTED_ATTRIBUTE_TRANSFORM_TARGETS,
            parameters={
                "branch_name": "{{ event.resource['infrahub.branch.name'] }}",
                "node_kind": "{{ event.resource['infrahub.node.kind'] }}",
                "object_id": "{{ event.resource['infrahub.node.id'] }}",
                "context": {
                    "__prefect_kind": "json",
                    "value": {
                        "__prefect_kind": "jinja",
                        "template": "{{ event.payload['context'] | tojson }}",
                    },
                },
            },
        )

        return cls(
            name=computed_attribute.computed_attribute.key_name,
            branch=branch,
            trigger=event_trigger,
            actions=[workflow],
        )
