# test_bm_sen.py
import unittest
import numpy as np

from smlr.solvers import bm_sen_solve, BMSENOptions


def sen_prox(B: np.ndarray, lam1: float, lam2: float, step: float = 1.0) -> np.ndarray:
    """Reference ("gold standard") SEN proximal operator used by the tests.

    Solves::

        prox_{step * [lam1 ||.||_* + (lam2/2) ||.||_F^2]}(B)
          = argmin_X  1/(2*step) ||X - B||_F^2 + lam1 ||X||_* + (lam2/2) ||X||_F^2

    Completing the square folds the ridge term into the fidelity term, so the
    problem equals a pure nuclear-norm prox with step ``step / (1 + step*lam2)``
    evaluated at ``B / (1 + step*lam2)``. That prox is singular-value
    soft-thresholding of the rescaled matrix.

    This uses a dense SVD and exists only for the unit tests; it is not part
    of the BM solver.

    Args:
        B: Matrix to apply the prox to.
        lam1: Nuclear-norm weight.
        lam2: Squared-Frobenius (ridge) weight. Negative values are allowed
            as long as ``1 + step*lam2 > 0``, which keeps the subproblem
            strongly convex.
        step: Prox step size; ``step <= 0`` is treated as the identity map.

    Returns:
        The prox of ``B``. When the map is the identity, this is a copy of
        ``B``, so the result never aliases the input.

    Raises:
        ValueError: If ``1 + step*lam2 <= 0``; the subproblem then has no
            unique minimiser.
    """
    if step <= 0.0 or (lam1 == 0.0 and lam2 == 0.0):
        # Identity map; copy so callers mutating the result don't touch B.
        return np.array(B, copy=True)

    denom = 1.0 + step * lam2
    if denom <= 0.0:
        raise ValueError(
            f"sen_prox requires 1 + step*lam2 > 0, got step={step}, lam2={lam2}"
        )

    # Apply the equivalent rescaling for every lam2. lam2 == 0 gives scale == 1,
    # and lam2 < 0 must not be silently ignored.
    scale = 1.0 / denom
    B_eff = B * scale
    step_eff = step * scale

    U, s, Vt = np.linalg.svd(B_eff, full_matrices=False)
    s_thr = np.maximum(s - step_eff * lam1, 0.0)
    if not np.any(s_thr):
        # Every singular value was thresholded away. Use B_eff's (float) dtype
        # so this path matches the dtype of the SVD path below.
        return np.zeros_like(B_eff)
    return (U * s_thr) @ Vt


def objective_F(B: np.ndarray, B_true: np.ndarray, lam1: float, lam2: float) -> float:
    """Evaluate the full convex SEN objective at ``B``.

    F(B) = 0.5 ||B - B_true||_F^2 + (lam2/2) ||B||_F^2 + lam1 ||B||_*

    The nuclear norm is taken as the sum of the singular values of ``B``.

    Returns:
        F(B) as a Python float.
    """
    residual = B - B_true
    fit_term = 0.5 * float(np.sum(residual * residual))
    ridge_term = 0.5 * lam2 * float(np.sum(B * B))
    singular_values = np.linalg.svd(B, compute_uv=False)
    nuclear_term = lam1 * float(np.sum(singular_values))
    return fit_term + ridge_term + nuclear_term


class TestBMSENQuadratic(unittest.TestCase):
    """BM-SEN tests on a quadratic loss whose SEN-regularised optimum is known in closed form."""

    def setUp(self) -> None:
        # Fixed seed: the synthetic problem is identical on every run.
        self.rng = np.random.default_rng(0)

    def test_bm_sen_quadratic_approximates_convex_solution(self) -> None:
        """BM-SEN on a simple quadratic + SEN objective should approach the convex solution.

        With f(B) = 0.5 ||B - B_true||_F^2, the convex problem
        min_B f(B) + lam1 ||B||_* + (lam2/2) ||B||_F^2 is solved exactly by
        prox_{1 * SEN}(B_true). That gives a reference to compare the
        factorised (Burer-Monteiro) solution against.
        """
        rng = self.rng
        m, n = 30, 20
        rank_true = 5
        # Factor rank exceeds rank_true. The convex solution has rank at most
        # rank_true (soft-thresholding cannot raise rank), so it is representable.
        rank_factor = 6

        lam1 = 0.1
        lam2 = 0.2

        # 1. Build a low-rank target B_true (rank <= rank_true).
        U0 = rng.standard_normal(size=(m, rank_true))
        V0 = rng.standard_normal(size=(n, rank_true))
        B_true = U0 @ V0.T

        # 2. Smooth part f(B) = 0.5 ||B - B_true||_F^2 and its gradient.
        def f_value_B(B: np.ndarray) -> float:
            R = B - B_true
            return 0.5 * float(np.sum(R * R))

        def grad_f_B(B: np.ndarray) -> np.ndarray:
            return B - B_true

        # 3. Exact convex solution via a single prox (one SVD).
        B_star = sen_prox(B_true, lam1=lam1, lam2=lam2, step=1.0)

        # 4. BM-SEN solver solution.
        options = BMSENOptions(
            step=1e-2,
            max_iter=10_000,
            tol=1e-6,
            verbose=False,  # no solver progress output in automated test runs
            record_objective=True,
            random_state=0,
        )
        U_hat, V_hat, info = bm_sen_solve(
            grad_f_B=grad_f_B,
            f_value_B=f_value_B,
            m=m,
            n=n,
            rank=rank_factor,
            lam1=lam1,
            lam2=lam2,
            options=options,
        )
        B_bm = U_hat @ V_hat.T

        # 5. Compare objective values and distance between the solutions.
        F_star = objective_F(B_star, B_true, lam1, lam2)
        F_bm = objective_F(B_bm, B_true, lam1, lam2)
        obj_scale = max(1.0, abs(F_star))

        # F_star is the global minimum of a strongly convex problem (lam2 > 0).
        # No point, including the BM iterate, may fall below it beyond rounding.
        # A violation means the reference (sen_prox / objective_F) is wrong.
        self.assertGreaterEqual(
            F_bm,
            F_star - 1e-8 * obj_scale,
            msg=f"BM objective {F_bm} is below the convex optimum {F_star}",
        )

        rel_obj_diff = abs(F_bm - F_star) / obj_scale
        rel_sol_diff = np.linalg.norm(B_bm - B_star) / max(1.0, np.linalg.norm(B_star))

        # The BM solution should be close to the convex one (thresholds may be
        # tuned empirically).
        self.assertLess(
            rel_obj_diff,
            1e-2,
            msg=f"relative objective gap {rel_obj_diff:.3e} (F_bm={F_bm}, F_star={F_star})",
        )
        self.assertLess(
            rel_sol_diff,
            5e-2,
            msg=f"relative solution distance {rel_sol_diff:.3e}",
        )

        # 6. The solver should report convergence with a small final gradient norm.
        final_grad_norm = info["final_grad_norm"]
        self.assertTrue(
            info["converged"],
            msg=f"solver did not converge (final_grad_norm={final_grad_norm})",
        )
        self.assertLess(final_grad_norm, 1e-4)


if __name__ == "__main__":
    # Allow running this module directly (e.g. `python test_bm_sen.py`);
    # unittest.main() collects and runs the TestCase classes in this module.
    unittest.main()
