# test_fista.py
import unittest
import numpy as np

from smlr.penalties import NoPenalty, L1Penalty, L2Penalty
from smlr.solvers import fista, FISTAOptions


class TestFISTAQuadratic(unittest.TestCase):
    """Check ``fista`` against problems whose minimizers are known in closed form."""

    # Per-iteration solver logging only clutters test output; flip to True
    # locally when debugging a failing case.
    VERBOSE = False

    def setUp(self) -> None:
        # Fresh fixed-seed generator per test, so each test's data does not
        # depend on which other tests ran before it.
        self.rng = np.random.default_rng(0)

    @staticmethod
    def _shifted_quadratic(a: np.ndarray):
        """Return ``(grad_f, f_val)`` for the smooth term f(x) = 0.5 * ||x - a||^2.

        grad f(x) = x - a. Its Hessian is the identity, so grad f is
        1-Lipschitz and ``step=1.0`` is the standard 1/L step size.
        """

        def grad_f(x: np.ndarray) -> np.ndarray:
            return x - a

        def f_val(x: np.ndarray) -> float:
            return 0.5 * float(np.sum((x - a) ** 2))

        return grad_f, f_val

    def test_no_penalty_quadratic(self) -> None:
        """Unpenalized: min 0.5||x - a||^2, whose minimizer is x* = a."""
        d = 5
        a = self.rng.standard_normal(size=d)
        grad_f, f_val = self._shifted_quadratic(a)

        x0 = np.zeros_like(a)
        penalty = NoPenalty()

        options = FISTAOptions(
            step=1.0,
            max_iter=50,
            tol=1e-10,
            verbose=self.VERBOSE,
        )
        x_hat, info = fista(x0, grad_f, penalty, f_value=f_val, options=options)

        self.assertTrue(info["converged"])
        self.assertLess(np.linalg.norm(x_hat - a), 1e-6)

    def test_l2_penalty_quadratic(self) -> None:
        """L2 penalty: min 0.5||x - a||^2 + (lam/2)||x||^2.

        Setting the gradient to zero gives (x - a) + lam * x = 0, so the
        minimizer is x* = a / (1 + lam).
        """
        d = 7
        a = self.rng.standard_normal(size=d)
        lam = 0.5
        grad_f, f_val = self._shifted_quadratic(a)

        x0 = np.zeros_like(a)
        penalty = L2Penalty(lam=lam, mask=None)  # P(x) = (lam/2)||x||^2

        options = FISTAOptions(
            step=1.0,
            max_iter=100,
            tol=1e-10,
            verbose=self.VERBOSE,
        )
        x_hat, info = fista(x0, grad_f, penalty, f_value=f_val, options=options)

        x_star = a / (1.0 + lam)

        self.assertTrue(info["converged"])
        self.assertLess(np.linalg.norm(x_hat - x_star), 1e-6)

    def test_l1_penalty_quadratic(self) -> None:
        """L1 penalty: min 0.5||x - a||^2 + lam||x||_1.

        The objective is separable, and each coordinate's minimizer is the
        soft-thresholding of a_i at level lam.
        """
        d = 7
        a = self.rng.standard_normal(size=d)
        lam = 0.3
        grad_f, f_val = self._shifted_quadratic(a)

        x0 = np.zeros_like(a)
        penalty = L1Penalty(lam=lam, mask=None)  # P(x) = lam ||x||_1

        def soft_threshold(z: np.ndarray, t: float) -> np.ndarray:
            # sign(z) * max(|z| - t, 0): shrink toward zero by t, clamping at 0.
            return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

        x_star = soft_threshold(a, lam)

        options = FISTAOptions(
            step=1.0,
            max_iter=200,
            tol=1e-8,
            verbose=self.VERBOSE,
        )
        x_hat, info = fista(x0, grad_f, penalty, f_value=f_val, options=options)

        self.assertTrue(info["converged"])
        self.assertLess(np.linalg.norm(x_hat - x_star), 1e-5)

    def test_matrix_quadratic_ridge(self) -> None:
        """Matrix ridge: min_B 0.5||A B - C||_F^2 + (lam/2)||B||_F^2.

        The closed-form minimizer is B* = (A^T A + lam I)^{-1} A^T C.
        """
        rng = self.rng
        p = 40  # rows of A and C (number of samples)
        m = 20  # rows of B (columns of A)
        n = 10  # columns of B and C

        A = rng.standard_normal(size=(p, m))
        B_true = rng.standard_normal(size=(m, n))
        noise = 0.1 * rng.standard_normal(size=(p, n))
        C = A @ B_true + noise

        lam = 0.5

        # f(B) = 0.5 ||A B - C||_F^2, grad f(B) = A^T (A B - C)
        def grad_f(B: np.ndarray) -> np.ndarray:
            R = A @ B - C
            return A.T @ R

        def f_val(B: np.ndarray) -> float:
            R = A @ B - C
            return 0.5 * float(np.sum(R * R))

        B0 = np.zeros_like(B_true)

        penalty = L2Penalty(lam=lam, mask=None)  # g(B) = (lam/2)||B||_F^2

        # grad f is Lipschitz with constant L = lambda_max(A^T A); use the
        # standard 1/L step. A^T A is symmetric, so eigvalsh applies.
        AtA = A.T @ A
        L = float(np.linalg.eigvalsh(AtA).max())
        step = 1.0 / L

        options = FISTAOptions(
            step=step,
            max_iter=500,
            tol=1e-6,
            verbose=self.VERBOSE,
            # Enabled so the objective-recording path runs, although the
            # recorded values are not inspected by this test.
            record_objective=True,
        )
        B_hat, info = fista(B0, grad_f, penalty, f_value=f_val, options=options)

        G = AtA + lam * np.eye(m)
        B_star = np.linalg.solve(G, A.T @ C)

        # Relative error, falling back to absolute error when ||B*|| < 1.
        err = np.linalg.norm(B_hat - B_star) / max(1.0, np.linalg.norm(B_star))

        # 1) The iterate should closely approximate the closed-form solution.
        self.assertLess(err, 1e-4)

        # 2) Guard against a degenerate run: with step 1/L this problem should
        #    not be solved in a handful of iterations.
        self.assertGreater(info["n_iter"], 10)

