from __future__ import annotations

import numpy as np
from abc import ABC, abstractmethod
from typing import Optional

# Shorthand for numpy arrays, used in the type annotations of this module.
Array = np.ndarray


class Penalty(ABC):
    """Abstract base class for penalties.

    A penalty object only knows how to
    - compute its value λ P(β)
    - apply its proximal operator: prox_{step * λ P}(β).

    It does *not* know anything about the likelihood / SMLR structure.

    Parameters
    ----------
    lam : float
        Regularization strength λ. Must be non-negative; NaN is rejected.
    mask : ndarray or None, optional
        Array with the same shape as β whose truthy entries are penalized.
        It is stored as a boolean copy. ``None`` means every entry is
        penalized.

    Raises
    ------
    ValueError
        If ``lam`` is negative or NaN.
    TypeError
        If ``mask`` is neither ``None`` nor a numpy array.
    """

    def __init__(self, lam: float, mask: Optional[Array] = None) -> None:
        # Phrased as ``not (lam >= 0)`` so that NaN, which compares False
        # against everything, is rejected together with negative values.
        if not (lam >= 0):
            raise ValueError("lam must be non-negative.")
        self.lam = float(lam)

        if mask is not None:
            if not isinstance(mask, np.ndarray):
                raise TypeError("mask must be a numpy array or None.")
            # astype() returns a fresh boolean array, so the stored mask is
            # independent of the caller's array even if it was already bool.
            mask = mask.astype(bool)
        # mask has the same shape as beta; True entries are penalized
        self.mask = mask

    # ------------------------------------------------------------------
    # utilities
    # ------------------------------------------------------------------
    def _apply_mask(self, beta: Array) -> Array:
        """Return beta with non-penalized entries zeroed, for value().

        When no mask is set, ``beta`` itself is returned (no copy).
        """
        if self.mask is None:
            return beta
        # Select rather than multiply: ``beta * mask`` would turn a
        # non-finite unpenalized entry into NaN (inf * 0) and poison the sum.
        return np.where(self.mask, beta, 0.0)

    @staticmethod
    def _soft_threshold(z: Array, t: float) -> Array:
        """Elementwise soft thresholding S(z, t) = sign(z) * max(|z| - t, 0).

        For ``t <= 0`` the input ``z`` is returned unchanged (same object).
        """
        if t <= 0:
            return z
        return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

    # ------------------------------------------------------------------
    # abstract interface
    # ------------------------------------------------------------------
    @abstractmethod
    def value(self, beta: Array) -> float:
        """Return the penalty value λ P(β)."""

    @abstractmethod
    def prox(self, beta: Array, step: float) -> Array:
        """Proximal operator for step * λ * P(β).

        Parameters
        ----------
        beta : ndarray
            Point at which the proximal operator is evaluated.
        step : float
            Step size multiplying the penalty.

        Returns
        -------
        beta_new : ndarray
            argmin_B (1 / (2 * step)) ||B - beta||_F^2 + λ P(B)
        """


class NoPenalty(Penalty):
    """Null penalty: contributes nothing and leaves β untouched."""

    def __init__(self) -> None:
        # Fixed configuration: zero strength, no mask.
        super().__init__(mask=None, lam=0.0)

    def value(self, beta: Array) -> float:  # noqa: ARG002
        """Return 0.0 regardless of ``beta``."""
        return 0.0

    def prox(self, beta: Array, step: float) -> Array:  # noqa: ARG002
        """Return ``beta`` itself; the proximal map is the identity."""
        return beta


class L1Penalty(Penalty):
    r"""Entry-wise lasso penalty λ ||β||_1 (restricted to the mask, if any)."""

    def value(self, beta: Array) -> float:
        """Return λ times the sum of absolute values of the penalized entries."""
        abs_sum = np.sum(np.abs(self._apply_mask(beta)))
        return self.lam * float(abs_sum)

    def prox(self, beta: Array, step: float) -> Array:
        """Soft-threshold the penalized entries by ``step * λ``.

        ``beta`` is returned as-is when λ is zero or ``step`` is not
        positive. With a mask, a copy of ``beta`` is returned in which only
        the masked entries have been thresholded.
        """
        if self.lam == 0.0 or step <= 0.0:
            return beta

        threshold = step * self.lam
        if self.mask is None:
            return self._soft_threshold(beta, threshold)

        # Only the masked entries are touched; the rest are copied through.
        shrunk = beta.copy()
        shrunk[self.mask] = self._soft_threshold(beta[self.mask], threshold)
        return shrunk


class L2Penalty(Penalty):
    r"""Ridge penalty (λ/2) ||β||_F^2 (restricted to the mask, if any)."""

    def value(self, beta: Array) -> float:
        """Return (λ/2) times the sum of squares of the penalized entries."""
        penalized = self._apply_mask(beta)
        sq_sum = float(np.sum(penalized * penalized))
        return 0.5 * self.lam * sq_sum

    def prox(self, beta: Array, step: float) -> Array:
        """Divide the penalized entries by ``1 + step * λ``.

        ``beta`` is returned as-is when λ is zero or ``step`` is not
        positive. With a mask, a copy of ``beta`` is returned in which only
        the masked entries have been rescaled.
        """
        if self.lam == 0.0 or step <= 0.0:
            return beta

        denom = 1.0 + step * self.lam
        if self.mask is None:
            return beta / denom

        # In-place division on the masked entries of a copy.
        shrunk = beta.copy()
        shrunk[self.mask] /= denom
        return shrunk


class ElasticNetPenalty(Penalty):
    r"""Elastic net penalty in the (lam, alpha) parameterization:

    λ [ α ||β||_1 + (1 - α)/2 ||β||_F^2 ].

    This is the form of glmnet's P_alpha(β).
    """

    def __init__(self, lam: float, alpha: float, mask: Optional[Array] = None) -> None:
        # The mixing weight is validated before the base class checks lam.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be in [0, 1].")
        super().__init__(lam=lam, mask=mask)
        self.alpha = float(alpha)
        # Split λ into its lasso and ridge shares: λ1 = λ α, λ2 = λ (1 - α).
        self.l1 = self.lam * self.alpha
        self.l2 = self.lam * (1.0 - self.alpha)

    def value(self, beta: Array) -> float:
        """Return λ1 * sum|β| + (λ2 / 2) * sum β² over the penalized entries."""
        penalized = self._apply_mask(beta)
        abs_sum = float(np.sum(np.abs(penalized)))
        sq_sum = float(np.sum(penalized * penalized))
        return self.l1 * abs_sum + 0.5 * self.l2 * sq_sum

    def prox(self, beta: Array, step: float) -> Array:
        """Soft-threshold by ``step * λ1``, then divide by ``1 + step * λ2``.

        ``beta`` is returned as-is when λ is zero or ``step`` is not
        positive. With a mask, a copy of ``beta`` is returned in which only
        the masked entries have been updated.
        """
        if self.lam == 0.0 or step <= 0.0:
            return beta

        threshold = step * self.l1
        denom = 1.0 + step * self.l2

        if self.mask is None:
            return self._soft_threshold(beta, threshold) / denom

        updated = beta.copy()
        thresholded = self._soft_threshold(beta[self.mask], threshold)
        updated[self.mask] = thresholded / denom
        return updated


class SENPenalty(Penalty):
    r"""Spectral elastic net (SEN) penalty for matrices B:

    λ1 ||B||_* + (λ2 / 2) ||B||_F^2.

    This matches the SEN–SMLR objective
        Ln(B) + λ1 ||B||_* + (λ2 / 2) ||B||_F^2
    where ||·||_* is the nuclear norm and ||·||_F is the Frobenius norm.

    The strengths are stored in ``lam1`` and ``lam2``; the inherited ``lam``
    attribute is always 0.0 and ``mask`` is always None for this class.

    Note: the original author remarks that this class is not really needed,
    because a "bm" factor method (presumably a Burer–Monteiro factorization;
    to be confirmed) will be used to avoid large-scale SVDs.

    Parameters
    ----------
    lam1 : float
        Weight λ1 of the nuclear norm; must be non-negative.
    lam2 : float, optional
        Weight λ2 of the ridge term; must be non-negative. Defaults to 0.0.

    Raises
    ------
    ValueError
        If ``lam1`` or ``lam2`` is negative or NaN.
    """

    def __init__(self, lam1: float, lam2: float = 0.0) -> None:
        # ``not (x >= 0)`` is also True for NaN, so NaN weights are rejected
        # instead of silently producing NaN values and prox results.
        if not (lam1 >= 0.0 and lam2 >= 0.0):
            raise ValueError("lam1 and lam2 must be non-negative.")

        # nuclear norm is global; we intentionally ignore any mask
        super().__init__(lam=0.0, mask=None)
        self.lam1 = float(lam1)
        self.lam2 = float(lam2)

    def value(self, beta: Array) -> float:
        """Return λ1 ||B||_* + (λ2/2) ||B||_F^2.

        Each term is evaluated only when its weight is non-zero, so the
        pure-ridge case (``lam1 == 0``) never runs an SVD.
        """
        total = 0.0

        if self.lam1 != 0.0:
            # nuclear norm: sum of singular values
            singular_values = np.linalg.svd(beta, compute_uv=False)
            total += self.lam1 * float(np.sum(singular_values))

        if self.lam2 != 0.0:
            total += 0.5 * self.lam2 * float(np.sum(beta * beta))

        return total

    def prox(self, beta: Array, step: float) -> Array:
        """Proximal operator for SEN penalty.

        Solve
            argmin_B (1 / (2 * step)) ||B - beta||_F^2
                     + lam1 ||B||_* + (lam2/2) ||B||_F^2.

        Closed form:
          - If lam1 = 0: ridge shrinkage beta / (1 + step * lam2).
          - Else: rescale beta and apply singular value thresholding
            with an effective step size.

        ``beta`` itself is returned when ``step <= 0`` or both weights are
        zero.
        """
        if step <= 0.0 or (self.lam1 == 0.0 and self.lam2 == 0.0):
            return beta

        # Pure ridge case: no nuclear norm
        if self.lam1 == 0.0:
            return beta / (1.0 + step * self.lam2)

        # General case: nuclear + ridge
        if self.lam2 > 0.0:
            # Completing the square merges the two quadratic terms into a
            # single proximal problem for the nuclear norm with
            #   beta_eff = beta / (1 + step * lam2)
            #   step_eff = step / (1 + step * lam2)
            scale = 1.0 / (1.0 + step * self.lam2)
            beta_eff = beta * scale
            step_eff = step * scale
        else:
            beta_eff = beta
            step_eff = step

        # Singular value thresholding on beta_eff
        U, s, Vt = np.linalg.svd(beta_eff, full_matrices=False)
        s_thr = np.maximum(s - step_eff * self.lam1, 0.0)

        # Every singular value was thresholded away: skip the reconstruction.
        if np.all(s_thr == 0.0):
            return np.zeros_like(beta)

        # Reconstruct with shrunk singular values (scales the columns of U)
        return (U * s_thr) @ Vt


def make_penalty(
    name: Optional[str],
    lam: float,
    *,
    alpha: Optional[float] = None,
    mask: Optional[Array] = None,
    lam2: Optional[float] = None,
) -> Penalty:
    """Build a penalty object from its name and parameters.

    Parameters
    ----------
    name : {None, 'l1', 'l2', 'elasticnet', 'sen'}
        Which penalty to build; matched case-insensitively.
        None yields NoPenalty().
    lam : float
        Main regularization strength.
        - 'l1' / 'l2': the weight λ.
        - 'elasticnet': the weight λ in P_alpha.
        - 'sen': the nuclear-norm weight λ1.
    alpha : float, optional
        Elastic net mixing parameter in [0, 1]. Required for, and only
        used by, 'elasticnet'.
    mask : ndarray[bool], optional
        Same shape as β; True entries are penalized
        (e.g. exclude intercept row by setting mask[0, :] = False).
        Not passed on for 'sen' (nuclear norm is global).
    lam2 : float, optional
        For 'sen': the ridge weight λ2 in (λ2/2)||B||_F^2.
        None is treated as 0.0 (pure nuclear norm).

    Returns
    -------
    penalty : Penalty

    Raises
    ------
    ValueError
        If ``name`` is not recognized, or if 'elasticnet' is requested
        without ``alpha``.
    """
    if name is None:
        return NoPenalty()

    name = name.lower()

    # Penalties that need extra arguments are handled explicitly.
    if name == "elasticnet":
        if alpha is None:
            raise ValueError("alpha must be provided for elasticnet.")
        return ElasticNetPenalty(lam=lam, alpha=alpha, mask=mask)

    if name == "sen":
        ridge_weight = float(lam2) if lam2 is not None else 0.0
        return SENPenalty(lam1=lam, lam2=ridge_weight)

    # The remaining penalties share the (lam, mask) constructor signature.
    simple_penalties = {"l1": L1Penalty, "l2": L2Penalty}
    if name in simple_penalties:
        return simple_penalties[name](lam=lam, mask=mask)

    raise ValueError(f"Unknown penalty: {name!r}")


if __name__ == '__main__':
    # Smoke check: print the value of each penalty on a small 2x2 matrix.
    demo_beta = np.array([[1.0, 2.0], [0.0, 3.0]])

    demo_cases = (
        ('l1 penalty:', L1Penalty, {'lam': 0.5}),
        ('l2 penalty:', L2Penalty, {'lam': 0.5}),
        ('ElasticNet penalty:', ElasticNetPenalty, {'lam': 2.0, 'alpha': 0.5}),
        ('SEN penalty:', SENPenalty, {'lam1': 0.5, 'lam2': 0.5}),
    )

    # Each penalty is built right before it is printed, one after another.
    for label, penalty_cls, kwargs in demo_cases:
        print(label, penalty_cls(**kwargs).value(demo_beta))

