import numpy as np
from dataclasses import dataclass
from typing import Optional, Literal, Dict, Any
import warnings

from .utils.functions import simplex_matrix  # type: ignore
from .penalties import make_penalty, Penalty  # you already have this
from .solvers import fista, FISTAOptions  # proximal gradient on B
from .solvers import bm_sen_solve, BMSENOptions  # BM factorization for SEN

# Shorthand for the NumPy array type used in signatures below.
Array = np.ndarray
# Accepted regularization names; None means "no regularization".
PenaltyName = Optional[
    Literal[
        "l1",
        "l2",
        "elasticnet",
        "sen",
        None,
    ]
]
# Accepted optimization backends.
SolverName = Literal["fista", "bm"]


@dataclass
class SMLR:
    """Simplex-Based Multinomial Logistic Regression (penalized).

    This version cleanly separates three concerns:
    - likelihood / SMLR geometry (this class)
    - penalties (L1/L2/ElasticNet/SEN) via `Penalty` objects
    - solvers (FISTA on B, BM factorization for SEN)

    Scores for the k classes are computed as ``X @ beta @ W`` where ``W`` is
    the (k-1, k) simplex matrix, and class probabilities are the row-wise
    softmax of those scores.

    Parameters
    ----------
    penalty : {'l1', 'l2', 'elasticnet', 'sen', None}
        Type of regularization. "sen" = spectral elastic net.
    lam : float
        Overall regularization strength (glmnet's λ).
        For `penalty='sen'` this is interpreted as the λ1 (nuclear norm) part;
        the λ2 (Frobenius) part is given by `sen_l2`.
    alpha : float, optional
        Elastic net mixing (only for penalty='elasticnet').
    sen_l2 : float, optional
        SEN Frobenius coefficient λ2. Only used if penalty == 'sen'.
    solver : {'fista', 'bm'}
        Optimization backend.
        - 'fista': proximal gradient on B (requires SVD for SEN)
        - 'bm': Burer–Monteiro factorization, SVD-free, **only** valid for SEN.
    fit_intercept : bool
        Whether to include an intercept row in β. With the 'fista' solver the
        intercept row is masked out of the penalty.
    rank : int, optional
        Rank for BM factorization (B = U V^T). Only used when solver='bm';
        defaults to ``min(k - 1, 10)``.

    Attributes (set by ``fit``)
    ---------------------------
    beta : ndarray, shape (M, k-1)
        Coefficients; M includes the intercept column when fit_intercept.
    W : ndarray, shape (k-1, k)
        Simplex matrix.
    k : int
        Number of classes.
    classes_ : ndarray, shape (k,)
        Sorted distinct labels seen in ``y``; ``predict`` returns values from it.
    """

    penalty: PenaltyName = None
    lam: float = 0.0
    alpha: Optional[float] = None
    sen_l2: float = 0.0
    solver: SolverName = "fista"
    fit_intercept: bool = True
    rank: Optional[int] = None

    # learned parameters
    beta: Optional[Array] = None  # shape (M, k-1)
    W: Optional[Array] = None     # simplex matrix, shape (k-1, k)
    k: Optional[int] = None       # number of classes
    _is_fit: bool = False

    # internal
    _penalty_obj: Optional[Penalty] = None

    # Sorted original labels (from np.unique). Declared last so the positional
    # order of the pre-existing dataclass fields in __init__ is unchanged.
    classes_: Optional[Array] = None

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------

    def fit(
        self,
        X: Array,
        y: Array,
        *,
        lr: float = 0.05,
        tol: float = 1e-7,
        max_iter: int = 10_000,
        verbose: bool = True,
    ) -> "SMLR":
        """Fit the SMLR model.

        If solver == 'fista': use proximal gradient on B with `Penalty`.
        If solver == 'bm'   : use BM factorization *only* for SEN penalty.

        Parameters
        ----------
        X : array-like, shape (n, p)
            Feature matrix without an intercept column (one is prepended when
            ``fit_intercept`` is True).
        y : array-like, shape (n,)
            Class labels. Any labels ``np.unique`` can sort are accepted; they
            are mapped to indices 0..k-1 internally and kept in ``classes_``.
        lr : float
            Step size for the 'fista' backend (the 'bm' backend does not use it).
        tol, max_iter :
            Convergence tolerance and iteration cap passed to the solver.
        verbose : bool
            Print a one-line solver summary after fitting.

        Returns
        -------
        SMLR
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            Unknown solver, X not 2-D, X/y length mismatch, fewer than two
            classes, or (for 'bm') a non-SEN penalty / invalid rank.
        """
        if self.solver not in ("fista", "bm"):
            raise ValueError(f"Unknown solver: {self.solver!r}")

        X = np.asarray(X)
        y = np.asarray(y)
        # Validate before adding the intercept: np.c_ would silently turn a
        # 1-D X into a 2-D matrix and hide the error.
        if X.ndim != 2:
            raise ValueError("X should be a 2D array.")
        if y.ndim >= 2:
            warnings.warn("Passed y is multidimensional; flattening.")
            y = y.ravel()
        if y.shape != (X.shape[0],):
            raise ValueError(
                f"y must have one label per row of X: got y.shape={y.shape}, "
                f"X.shape={X.shape}."
            )

        # Map arbitrary labels onto 0..k-1; indexing np.eye(k) with raw labels
        # breaks for non-zero-based, non-contiguous, float or string labels.
        classes, y_idx = np.unique(y, return_inverse=True)
        if classes.size < 2:
            raise ValueError("SMLR needs at least two distinct classes in y.")

        # A previous fit is no longer valid once we start overwriting state.
        self._is_fit = False

        if self.fit_intercept:
            X = np.c_[np.ones(X.shape[0]), X]

        m = X.shape[1]
        self.classes_ = classes
        self.k = int(classes.size)
        self.W = simplex_matrix(self.k)  # (k-1, k)

        # Initialize coefficients (m x (k-1)); fixed seed for reproducibility.
        rng = np.random.default_rng(0)
        self.beta = 0.01 * rng.standard_normal(size=(m, self.k - 1))

        # One-hot labels, columns ordered like self.classes_.
        Y = np.eye(self.k)[y_idx.ravel()]

        if self.solver == "fista":
            self._fit_fista(X, Y, lr=lr, tol=tol, max_iter=max_iter, verbose=verbose)
        else:
            self._fit_bm(X, Y, tol=tol, max_iter=max_iter, verbose=verbose)

        self._is_fit = True
        return self

    # ------------------------------------------------------------------
    # Shared objective (mean multinomial negative log-likelihood)
    # ------------------------------------------------------------------

    def _make_objective(self, X: Array, Y: Array):
        """Return ``(f_value, grad_f)`` closures for the mean NLL as a function of β.

        ``X`` already contains the intercept column (if any) and ``Y`` is the
        one-hot label matrix. Both backends optimize this same objective.
        """
        W = self.W
        n = X.shape[0]

        def f_value(beta: Array) -> float:
            scores = X @ beta @ W  # (n, k)
            # Stable log-softmax: exactly the function whose gradient grad_f
            # returns (no +1e-15 fudge that distorts tiny probabilities).
            return -float(np.sum(Y * self._log_softmax(scores))) / n

        def grad_f(beta: Array) -> Array:
            probs = self._softmax(X @ beta @ W)
            # Chain rule through scores = X β W.
            return X.T @ ((probs - Y) @ W.T) / n

        return f_value, grad_f

    # ------------------------------------------------------------------
    # FISTA backend: proximal gradient on B (matrix β)
    # ------------------------------------------------------------------

    def _make_penalty_for_beta(self, beta_shape: tuple[int, int]) -> Penalty:
        """Build the `Penalty` for β, masking out the intercept row if present."""
        mask = np.ones(beta_shape, dtype=bool)
        if self.fit_intercept:
            mask[0, :] = False

        if self.penalty == "sen":
            # For FISTA, SEN penalty is handled as a specialized Penalty
            # object that internally uses SVD for prox.
            from .penalties import SENPenalty  # type: ignore

            return SENPenalty(lam1=self.lam, lam2=self.sen_l2, mask=mask)

        # Otherwise use standard glmnet-style penalties
        return make_penalty(self.penalty, lam=self.lam, alpha=self.alpha, mask=mask)

    def _fit_fista(
        self,
        X: Array,
        Y: Array,
        *,
        lr: float,
        tol: float,
        max_iter: int,
        verbose: bool,
    ) -> None:
        """Optimize β with FISTA, starting from the current ``self.beta``."""
        assert self.beta is not None and self.W is not None and self.k is not None

        penalty = self._make_penalty_for_beta(self.beta.shape)
        self._penalty_obj = penalty
        f_value, grad_f = self._make_objective(X, Y)

        options = FISTAOptions(step=lr, max_iter=max_iter, tol=tol, verbose=verbose, record_objective=False)
        beta_hat, info = fista(self.beta, grad_f, penalty, f_value=f_value, options=options)
        if verbose:
            print(f"[SMLR-FISTA] n_iter={info['n_iter']}, converged={info['converged']}")
        self.beta = beta_hat

    # ------------------------------------------------------------------
    # BM backend: Burer–Monteiro factorization for SEN only
    # ------------------------------------------------------------------

    def _fit_bm(
        self,
        X: Array,
        Y: Array,
        *,
        tol: float,
        max_iter: int,
        verbose: bool,
    ) -> None:
        """Optimize β = U V^T with the Burer–Monteiro SEN solver.

        Raises
        ------
        ValueError
            If the penalty is not 'sen' or ``rank`` is smaller than 1.
        """
        if self.penalty != "sen":
            raise ValueError("BM solver is only implemented for penalty='sen'.")
        assert self.W is not None and self.k is not None

        n_rows = X.shape[1]
        n_cols = self.k - 1

        # BM works on B = β directly (shape m x (k-1)) via factorization B = U V^T.
        if self.rank is None:
            rank = min(n_cols, 10)
        elif self.rank < 1:
            raise ValueError(f"rank must be a positive integer, got {self.rank!r}.")
        else:
            rank = self.rank

        f_value_B, grad_f_B = self._make_objective(X, Y)

        bm_opts = BMSENOptions(
            step=1e-3,
            max_iter=max_iter,
            tol=tol,
            verbose=verbose,
            record_objective=False,
            random_state=0,
        )

        # NOTE(review): unlike the FISTA path, no intercept mask is passed to
        # bm_sen_solve, so the intercept row appears to be penalized here too —
        # confirm against bm_sen_solve's contract.
        U, V, info = bm_sen_solve(
            grad_f_B=grad_f_B,
            f_value_B=f_value_B,
            m=n_rows,
            n=n_cols,
            rank=rank,
            lam1=self.lam,
            lam2=self.sen_l2,
            options=bm_opts,
        )
        if verbose:
            print(
                f"[SMLR-BM] rank={rank}, n_iter={info['n_iter']}, "
                f"converged={info['converged']}, grad_norm={info['final_grad_norm']:.3e}"
            )
        self.beta = U @ V.T

    # ------------------------------------------------------------------
    # Prediction API
    # ------------------------------------------------------------------

    @staticmethod
    def _softmax(scores: Array) -> Array:
        """Row-wise softmax, shifted by the row max for numerical stability."""
        scores = scores - np.max(scores, axis=1, keepdims=True)
        exp_s = np.exp(scores)
        return exp_s / np.sum(exp_s, axis=1, keepdims=True)

    @staticmethod
    def _log_softmax(scores: Array) -> Array:
        """Row-wise log-softmax (log-sum-exp with max shift); finite for finite scores."""
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        return shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    def _check_is_fit(self) -> None:
        """Raise RuntimeError unless the model has usable fitted parameters."""
        if not self._is_fit or self.beta is None or self.W is None:
            raise RuntimeError("Call .fit() before using the model.")

    def predict_proba(self, X: Array) -> Array:
        """Return class probabilities, shape (n, k), columns ordered like ``classes_``."""
        self._check_is_fit()
        X = np.asarray(X)
        if self.fit_intercept:
            X = np.c_[np.ones(X.shape[0]), X]
        scores = X @ self.beta @ self.W  # type: ignore[arg-type]
        return self._softmax(scores)

    def predict(self, X: Array) -> Array:
        """Return the most probable label for each row of ``X``.

        Labels come from ``classes_``; if it is unset (parameters assigned
        without calling ``fit``) the raw column index 0..k-1 is returned.
        """
        idx = np.argmax(self.predict_proba(X), axis=1)
        if self.classes_ is None:
            return idx
        return self.classes_[idx]

    # ------------------------------------------------------------------

    def _get_params(self) -> Dict[str, Any]:
        """Return hyperparameters and fitted state.

        Every key is a dataclass field, so ``SMLR(**params)`` rebuilds a model
        that can predict (``k``/``W``/``classes_`` are included for that).
        """
        return {
            "beta": self.beta,
            "lam": self.lam,
            "penalty": self.penalty,
            "alpha": self.alpha,
            "sen_l2": self.sen_l2,
            "solver": self.solver,
            "fit_intercept": self.fit_intercept,
            "rank": self.rank,
            "_is_fit": self._is_fit,
            "k": self.k,
            "W": self.W,
            "classes_": self.classes_,
        }
