from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property

import numpy as np
from numpy.typing import NDArray
from sortedcontainers import SortedSet

from max_div.sampling._constraint_helpers import _build_array_repr, _build_con_membership

from ._constraints import Constraint
from ._distance import (
    DistanceMetric,
    compute_pdist,
    compute_separation,
    update_separation_add,
    update_separation_remove,
)
from ._diversity import DiversityMetric
from ._score import Score, ScoreGenerator


# =================================================================================================
#  Solver State
# =================================================================================================
class SolverState:
    """
    Mutable state of a diversity-selection solver.

    Holds the current split of all n vectors into a 'selected' and a 'not selected' set, together with the derived
    quantities that must be kept in sync with that split: the separation of every vector wrt the selected set, the
    remaining constraint bounds and the resulting score.  Solver strategies modify the state through add() / remove()
    and can checkpoint / roll back changes via set_snapshot() / restore_snapshot().  Instances are normally created
    with the new() factory method.
    """

    # -------------------------------------------------------------------------
    #  Construction & Configuration
    # -------------------------------------------------------------------------
    def __init__(
        self,
        n: np.int32,
        k: np.int32,
        pdist: NDArray[np.float32],
        score_generator: ScoreGenerator,
        selected: SortedSet,
        not_selected: SortedSet,
        sep_global: NDArray[np.float32],
        sep_selected: NDArray[np.float32],
        con_values: NDArray[np.int32],
        con_indices: NDArray[np.int32],
        con_membership: dict[np.int32, list[np.int32]],
    ):
        """
        Initialize the SolverState.  The constructor is not intended to be used directly, instead use new().

        Problem Dimensions:

            n  : total number of vectors
          ( d  : dimensionality of each vector  (not visible here, since we get pair-wise distances directly) )
            k  : target selection size
            m  : number of constraints

        :param n: (np.int32) number of vectors
        :param k: (np.int32) target number of selected vectors
        :param pdist: (np.ndarray[np.float32]) condensed pair-wise distance vector (1D array of size (n*(n-1))//2)
        :param score_generator: (ScoreGenerator) score generator to compute scores for current state
        :param selected: (SortedSet) set of selected indices (np.int32)
        :param not_selected: (SortedSet) set of not selected indices (np.int32)
        :param sep_global: (np.ndarray[np.float32]) separation of each vector wrt all other vectors
                           (one entry per vector)
        :param sep_selected: (np.ndarray[np.float32]) n-sized 1D array with separation of each vector wrt the
                             selected set
        :param con_values: (np.ndarray[np.int32]) (m x 2) array with, per constraint, the lower (column 0) and upper
                           (column 1) bound on the number of additional samples still to be selected
                           (as generated by _build_array_repr)
        :param con_indices: (np.ndarray[np.int32]) 1d array with all indices per constraint
                            (as generated by _build_array_repr)
        :param con_membership: (dict[np.int32, list[np.int32]]) mapping from index to list of constraints it
                               belongs to (as generated by _build_con_membership)
        """
        self._n = n  # READ-ONLY
        self._k = k  # READ-ONLY

        # distances
        self._pdist = pdist  # READ-ONLY
        self._sep_global = sep_global  # READ-ONLY
        self._sep_selected = sep_selected

        # scoring
        self._score_generator = score_generator  # READ-ONLY
        self._score: Score | None = None  # filled in by _update_score() below

        # selection
        self._selected = selected
        self._not_selected = not_selected

        # constraints
        self._con_values = con_values  # min/max counts of extra samples needed on top of current selection
        self._con_indices = con_indices  # READ-ONLY
        self._con_membership = con_membership  # READ-ONLY

        # snapshot
        self._snapshot: Snapshot = Snapshot.empty()

        # finalize
        self._update_score()

    # -------------------------------------------------------------------------
    #  Copy
    # -------------------------------------------------------------------------
    def copy(self) -> SolverState:
        """
        Return a deep copy of the current state.

        All arrays, sets and the constraint membership mapping are copied (including the read-only ones), so the
        copy is fully independent of this instance.  The snapshot is NOT copied: the returned state starts without
        a snapshot, and its score is recomputed on construction.
        """
        return SolverState(
            n=self._n,
            k=self._k,
            pdist=self._pdist.copy(),
            score_generator=self._score_generator.copy(),
            selected=self._selected.copy(),
            not_selected=self._not_selected.copy(),
            sep_global=self._sep_global.copy(),
            sep_selected=self._sep_selected.copy(),
            con_values=self._con_values.copy(),
            con_indices=self._con_indices.copy(),
            con_membership=deepcopy(self._con_membership),
        )

    # -------------------------------------------------------------------------
    #  Main API - used by solver strategies to modify state
    # -------------------------------------------------------------------------
    def set_snapshot(self):
        """
        When called, this method internally saves the current state as a snapshot (possibly overwriting any previous
        snapshot).  Such a snapshot can be restored using the restore_snapshot() method,  with any actions that happened
        in between (add, remove) being undone.
        """

        # NOTE: we create copies, such that add(.) and remove(.) cannot influence the snapshot after it was taken
        self._snapshot.selected = self._selected.copy()
        self._snapshot.not_selected = self._not_selected.copy()
        self._snapshot.sep_selected = self._sep_selected.copy()
        self._snapshot.con_values = self._con_values.copy()
        self._snapshot.is_valid = True

    def restore_snapshot(self):
        """
        This method restores the state of this object to the state saved in the last call to set_snapshot().
        Any actions that happened in between (add, remove) are undone.  If set_snapshot() hasn't been called before,
        a ValueError is raised.  After restoring the snapshot, it gets cleared, such that subsequent calls to
        restore_snapshot() without an intermediate call to set_snapshot() will again raise a ValueError.
        """
        if not self._snapshot.is_valid:
            raise ValueError("Cannot restore snapshot: set_snapshot() not called before.")

        # restore snapshot (no copy needed; we will clear the snapshot)
        self._selected = self._snapshot.selected
        self._not_selected = self._snapshot.not_selected
        self._sep_selected = self._snapshot.sep_selected
        self._con_values = self._snapshot.con_values

        # restore score (it is not stored in the snapshot, so recompute it)
        self._update_score()

        # clear snapshot after restoring
        self._snapshot.clear()

    def add(self, index: int | np.int32):
        """
        Add a vector to the selection and update separation, constraint bounds and score accordingly.

        :param index: index of a vector that is currently not selected
        :raises ValueError: if the index is already selected
        :raises KeyError: if the index is in neither set (e.g. out of range); the state is left unmodified
        """
        # --- validation ----------------------------------
        index = np.int32(index)
        if index in self._selected:
            raise ValueError("Cannot add index that is already selected.")

        # --- selection -----------------------------------
        # NOTE: remove from not_selected FIRST, so that an invalid index raises (KeyError) before any state has been
        #       modified, rather than leaving a bogus index behind in _selected.
        self._not_selected.remove(index)
        self._selected.add(index)

        # --- separation ----------------------------------
        update_separation_add(self._sep_selected, self._pdist, self._n, index)

        # --- constraints ---------------------------------
        for i_con in self._con_membership[index]:
            self._con_values[i_con, 0] -= 1  # decrease min_count
            self._con_values[i_con, 1] -= 1  # decrease max_count

        # --- score ---------------------------------------
        self._update_score()

    def remove(self, index: int | np.int32):
        """
        Remove a vector from the selection and update separation, constraint bounds and score accordingly.

        :param index: index of a vector that is currently selected
        :raises ValueError: if the index is currently not selected
        :raises KeyError: if the index is in neither set (e.g. out of range); the state is left unmodified
        """
        # --- validation ----------------------------------
        index = np.int32(index)
        if index in self._not_selected:
            raise ValueError("Cannot remove index that is not selected.")

        # --- selection -----------------------------------
        # removing from _selected first raises before any mutation for an index that is in neither set
        self._selected.remove(index)
        self._not_selected.add(index)

        # --- separation ----------------------------------
        # the selection *after* removal is passed in (presumably so separations can be recomputed from the remaining
        # selected vectors - TODO confirm against update_separation_remove)
        update_separation_remove(self._sep_selected, self._pdist, self._n, index, self.selected_index_array)

        # --- constraints ---------------------------------
        for i_con in self._con_membership[index]:
            self._con_values[i_con, 0] += 1  # increase min_count
            self._con_values[i_con, 1] += 1  # increase max_count

        # --- score ---------------------------------------
        self._update_score()

    # -------------------------------------------------------------------------
    #  Properties
    # -------------------------------------------------------------------------
    @cached_property
    def k(self) -> np.int32:
        """Return target selection size."""
        return self._k

    @cached_property
    def m(self) -> int:
        """Return total number of constraints."""
        return self._con_values.shape[0]

    @cached_property
    def n(self) -> np.int32:
        """Return total number of vectors."""
        return self._n

    @cached_property
    def has_constraints(self) -> bool:
        """Return True if >0 constraints are defined."""
        return self._con_values.shape[0] > 0

    @cached_property
    def con_indices(self) -> NDArray[np.int32]:
        """Return constraint indices array."""
        return self._con_indices  # should not be modified (!)

    @property  # not cached, since this array is expected to change
    def con_values(self) -> NDArray[np.int32]:
        """Return current (m x 2) constraint values array (remaining min / max counts per constraint)."""
        return self._con_values  # should not be modified (!)

    @property
    def selected_index_array(self) -> NDArray[np.int32]:
        """Return selected indices as a numpy array of np.int32."""
        return np.array(self._selected, dtype=np.int32)

    @property
    def not_selected_index_array(self) -> NDArray[np.int32]:
        """Return not selected indices as a numpy array of np.int32."""
        return np.array(self._not_selected, dtype=np.int32)

    @property
    def selected_separation_array(self) -> NDArray[np.float32]:
        """Return separation of selected vectors wrt other selected vectors as a numpy array of np.float32."""
        return self._sep_selected[list(self._selected)]

    @property
    def not_selected_separation_array(self) -> NDArray[np.float32]:
        """Return separation of not selected vectors wrt selected vectors as a numpy array of np.float32."""
        return self._sep_selected[list(self._not_selected)]

    @property
    def global_separation_array(self) -> NDArray[np.float32]:
        """Return global separation of all vectors wrt all other vectors as a numpy array of np.float32."""
        return self._sep_global  # should not be modified (!)

    # -------------------------------------------------------------------------
    #  Scoring
    # -------------------------------------------------------------------------
    def _update_score(self):
        """Recompute self._score from the current selection size, constraint values and separations."""
        self._score = self._score_generator.compute_score(
            n_selected=len(self._selected),
            con_values=self._con_values,
            selected_separation_array=self.selected_separation_array,
        )

    @property
    def score(self) -> Score:
        """
        Return overall score of the current selection as a multi-component prioritized Score object.
        """
        return self._score

    # -------------------------------------------------------------------------
    #  Factory methods
    # -------------------------------------------------------------------------
    @classmethod
    def new(
        cls,
        vectors: np.ndarray,
        k: int,
        distance_metric: DistanceMetric,
        diversity_metric: DiversityMetric,
        diversity_tie_breakers: list[DiversityMetric],
        constraints: list[Constraint],
    ) -> SolverState:
        """
        Create a fresh state with an empty selection for the given problem.

        :param vectors: array whose first axis indexes the n vectors
        :param k: target number of selected vectors
        :param distance_metric: metric used to compute pair-wise distances
        :param diversity_metric: primary diversity metric used for scoring
        :param diversity_tie_breakers: additional diversity metrics used for scoring
        :param constraints: list of constraints on the selection
        :return: new state in which no vector is selected
        """
        # --- distances ---
        n = np.int32(vectors.shape[0])
        pdist = compute_pdist(vectors, distance_metric)
        sep_global = compute_separation(pdist, n)
        sep_selected = np.full(n, fill_value=np.inf, dtype=np.float32)  # nothing selected yet -> infinite separation

        # --- selection ---
        selected = SortedSet()
        not_selected = SortedSet(np.arange(n, dtype=np.int32))

        # --- constraints ---
        con_values, con_indices = _build_array_repr(constraints)
        con_membership = _build_con_membership(n, constraints)

        # --- score generator ---
        score_generator = ScoreGenerator(
            n=n,
            k=k,
            diversity_metric=diversity_metric,
            diversity_tie_breakers=diversity_tie_breakers,
            constraints=constraints,
        )

        # --- construct & return ---
        return cls(
            n=n,
            k=np.int32(k),
            pdist=pdist,
            score_generator=score_generator,
            selected=selected,
            not_selected=not_selected,
            sep_global=sep_global,
            sep_selected=sep_selected,
            con_values=con_values,
            con_indices=con_indices,
            con_membership=con_membership,
        )


# =================================================================================================
#  Helper Classes
# =================================================================================================
@dataclass
class Snapshot:
    """
    Class internally used by SolverState to store snapshots of its state.  This class models a subset of the fields
    of the SolverState class, restricting itself to those that can be modified after construction.

    The fields are only meaningful while ``is_valid`` is True.  An empty / cleared snapshot holds shared module-level
    placeholder objects, which must therefore never be mutated in place.
    """

    is_valid: bool  # True iff the fields below hold a usable snapshot

    selected: SortedSet  # sorted set of np.int32
    not_selected: SortedSet  # sorted set of np.int32

    sep_selected: NDArray[np.float32]  # n-sized 1D array with separation of each vector wrt selected set
    con_values: NDArray[np.int32]  # (m x 2)-sized array with current status of constraint bounds

    # -------------------------------------------------------------------------
    #  Modification / Factory
    # -------------------------------------------------------------------------
    def clear(self):
        """Clear the snapshot, making it invalid (fields are reset to shared empty placeholders)."""
        self.is_valid = False
        self.selected = _EMPTY_SORTED_SET
        self.not_selected = _EMPTY_SORTED_SET
        self.sep_selected = _EMPTY_NP_ARRAY_FLOAT32
        self.con_values = _EMPTY_NP_ARRAY_INT32

    @classmethod
    def empty(cls) -> Snapshot:
        """Create and return an empty/invalid snapshot (fields set to shared empty placeholders)."""
        return cls(
            is_valid=False,
            selected=_EMPTY_SORTED_SET,
            not_selected=_EMPTY_SORTED_SET,
            sep_selected=_EMPTY_NP_ARRAY_FLOAT32,
            con_values=_EMPTY_NP_ARRAY_INT32,
        )


# Shared empty placeholders used by Snapshot.clear() / Snapshot.empty(), so that clearing a snapshot does not
# allocate fresh containers every time.  Being shared, they must never be mutated in place.
_EMPTY_NP_ARRAY_INT32 = np.empty(0, dtype=np.int32)
_EMPTY_NP_ARRAY_FLOAT32 = np.empty(0, dtype=np.float32)
_EMPTY_SORTED_SET = SortedSet()
