# solver_fista.py
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

from .penalties import Penalty  # adjust this import path to your project layout

# Type aliases shared by the solvers in this module.
Array = np.ndarray                      # NumPy array: vector or matrix
GradFunc = Callable[[Array], Array]     # maps B to the gradient of the smooth term
ValueFunc = Callable[[Array], float]    # maps B to the value of the smooth term


@dataclass
class FISTAOptions:
    """Settings for the ``fista`` solver.

    Every field has a default, so ``FISTAOptions()`` is a usable configuration.
    """

    # Fixed proximal-gradient step t (the usual choice is t = 1/L).
    step: float = 0.01
    # Hard cap on the number of iterations.
    max_iter: int = 1_000
    # Stop once ||beta_{k+1} - beta_k|| / max(1, ||beta_k||) drops below this.
    tol: float = 1.0e-6
    # Print progress every 50 iterations and on convergence.
    verbose: bool = False
    # Record the objective at each iterate (only when f_value is supplied).
    record_objective: bool = False


def fista(
    beta0: Array,
    grad_f: GradFunc,
    penalty: Penalty,
    *,
    f_value: Optional[ValueFunc] = None,
    options: Optional[FISTAOptions] = None,
) -> tuple[Array, Dict[str, Any]]:
    r"""Generic FISTA (accelerated proximal gradient) solver.

    Solves

        minimize_B   F(B) = f(B) + penalty.value(B)

    where ``f`` is smooth with gradient supplied by ``grad_f``, and ``penalty``
    exposes ``value(B)`` and ``prox(Z, step=t)``.

    Parameters
    ----------
    beta0 : Array
        Starting point (vector or matrix). It is copied, not modified.
    grad_f : GradFunc
        Callable returning the gradient of f at B.
    penalty : Penalty
        Penalty instance (e.g. NoPenalty / L1 / L2 / ElasticNet / SEN)
        providing ``prox`` and ``value``.
    f_value : ValueFunc, optional
        Callable returning f(B). Used only to record
        F(B) = f(B) + penalty.value(B) when ``options.record_objective`` is True.
    options : FISTAOptions, optional
        Step size, iteration cap, tolerance and logging flags.
        Defaults to ``FISTAOptions()``.

    Returns
    -------
    beta : Array
        The final iterate.
    info : dict
        ``'n_iter'`` (iterations performed; 0 when ``max_iter <= 0``),
        ``'converged'``, and ``'obj_history'`` when recording.

    Raises
    ------
    ValueError
        If ``options.step`` is not positive.
    """
    if options is None:
        options = FISTAOptions()

    t_step = float(options.step)
    if t_step <= 0.0:
        raise ValueError("FISTA step size must be positive.")

    record = options.record_objective and f_value is not None

    beta = beta0.copy()
    y = beta0.copy()
    t = 1.0  # Nesterov momentum scalar t_k

    obj_history = []
    converged = False
    # Tracked explicitly so that max_iter <= 0 does not leave the count unbound.
    n_iter = 0

    def _norm(x: Array) -> float:
        return float(np.linalg.norm(x))

    for it in range(options.max_iter):
        n_iter = it + 1

        # Gradient step at the extrapolated point: z_k = y_k - t * grad f(y_k).
        z = y - t_step * grad_f(y)

        # Proximal step: beta_{k+1} = prox_{t * penalty}(z_k).
        beta_next = penalty.prox(z, step=t_step)

        # Nesterov extrapolation for the next gradient evaluation.
        t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        y = beta_next + ((t - 1.0) / t_next) * (beta_next - beta)

        # Relative parameter change, floored at 1 to avoid dividing by ~0.
        rel_change = _norm(beta_next - beta) / max(1.0, _norm(beta))

        if record:
            obj_history.append(f_value(beta_next) + penalty.value(beta_next))

        if options.verbose and (it + 1) % 50 == 0:
            msg = f"[FISTA] iter={it+1}, rel_change={rel_change:.3e}"
            if record:
                msg += f", obj={obj_history[-1]:.6f}"
            print(msg)

        beta = beta_next
        t = t_next

        if rel_change < options.tol:
            converged = True
            if options.verbose:
                print(f"[FISTA] converged at iter={it+1}, rel_change={rel_change:.3e}")
            break

    info: Dict[str, Any] = {
        "n_iter": n_iter,
        "converged": converged,
    }
    if record:
        info["obj_history"] = np.array(obj_history)

    return beta, info


# ---------------------------------------------------------------------------
# Burer–Monteiro factorized solver: callables acting on the full matrix B.
# ---------------------------------------------------------------------------

GradFB = Callable[[Array], Array]    # B (m x n) -> gradient of f at B
ValueFB = Callable[[Array], float]   # B (m x n) -> f(B)

@dataclass
class BMSENOptions:
    """Settings for ``bm_sen_solve``, the Burer–Monteiro SEN solver."""

    # Fixed gradient-descent step applied to both factors U and V.
    step: float = 0.001
    # Hard cap on gradient-descent iterations.
    max_iter: int = 5000
    # Stop once the joint gradient norm ||(grad_U, grad_V)|| drops below this.
    tol: float = 1.0e-6
    # Print progress every 50 iterations and on convergence.
    verbose: bool = False
    # Record the factorized objective after every step (needs f_value_B).
    record_objective: bool = False
    # Seed passed to np.random.default_rng for the initial U and V.
    random_state: Optional[int] = None


def bm_sen_solve(
    grad_f_B: GradFB,
    *,
    f_value_B: Optional[ValueFB] = None,
    m: int,
    n: int,
    rank: int,
    lam1: float,
    lam2: float = 0.0,
    options: Optional[BMSENOptions] = None,
) -> tuple[Array, Array, Dict[str, Any]]:
    r"""Burer–Monteiro (factorized) gradient-descent solver for the SEN objective.

    Objective in B-space:

        F(B) = f(B) + (lam2 / 2) ||B||_F^2 + lam1 ||B||_*

    where f is smooth (gradient supplied by ``grad_f_B``) and lam1, lam2 >= 0.
    Writing B = U V^T with U in R^{m x r} and V in R^{n x r}, the solver runs
    plain gradient descent on the factorized objective

        tilde_F(U, V) = f(U V^T) + (lam2 / 2) ||U V^T||_F^2
                        + (lam1 / 2) (||U||_F^2 + ||V||_F^2),

    whose gradients are

        grad_U = G V + lam1 U,    grad_V = G^T U + lam1 V,
        G = grad_B f(B) + lam2 B, B = U V^T.

    Parameters
    ----------
    grad_f_B : GradFB
        Callable returning grad_B f(B) for an (m, n) matrix B.
    f_value_B : ValueFB, optional
        Callable returning f(B). Used only when ``options.record_objective``
        is True. Defaults to None, meaning no objective is recorded.
    m, n : int
        Shape of B.
    rank : int
        Number of columns r of U and V.
    lam1, lam2 : float
        Nuclear-norm and squared-Frobenius weights (non-negative).
    options : BMSENOptions, optional
        Step size, iteration cap, tolerance, logging flags and RNG seed.

    Returns
    -------
    U, V : Array
        Factor matrices of shape (m, r) and (n, r), with B_hat = U @ V.T.
    info : dict
        Contains:

        - ``'n_iter'``: iterations performed.
        - ``'converged'``.
        - ``'final_grad_norm'``: norm of the last evaluated
          (grad_U, grad_V); NaN if no iteration ran.
        - ``'obj_history'`` (optional): tilde_F after each step.

    Raises
    ------
    ValueError
        If rank, m or n is not positive, lam1 or lam2 is negative, or the step
        size is not positive.
    """
    if options is None:
        options = BMSENOptions()

    if rank <= 0:
        raise ValueError("rank must be positive.")
    if lam1 < 0 or lam2 < 0:
        raise ValueError("lam1 and lam2 must be non-negative.")
    if m <= 0 or n <= 0:
        raise ValueError("m and n must be positive.")

    step = float(options.step)
    if step <= 0.0:
        raise ValueError("step size must be positive.")

    record = options.record_objective and f_value_B is not None

    # Small random start: U = V = 0 makes both gradients vanish, so descent
    # must not begin exactly there.
    rng = np.random.default_rng(options.random_state)
    U = 0.01 * rng.standard_normal(size=(m, rank))
    V = 0.01 * rng.standard_normal(size=(n, rank))

    obj_history = []
    converged = False
    # Initialized so that max_iter <= 0 still yields a well-formed info dict.
    n_iter = 0
    grad_norm = float("nan")

    for it in range(options.max_iter):
        n_iter = it + 1

        B = U @ V.T                      # (m, n)
        G = grad_f_B(B) + lam2 * B       # gradient of the smooth part in B

        # Chain rule plus the gradients of (lam1/2)||U||^2 and (lam1/2)||V||^2.
        grad_U = G @ V + lam1 * U        # (m, r)
        grad_V = G.T @ U + lam1 * V      # (n, r)
        grad_norm = float(
            np.sqrt(float(np.sum(grad_U * grad_U)) + float(np.sum(grad_V * grad_V)))
        )

        U_next = U - step * grad_U
        V_next = V - step * grad_V

        if record:
            B_next = U_next @ V_next.T
            smooth = f_value_B(B_next) + 0.5 * lam2 * float(np.sum(B_next * B_next))
            reg = 0.5 * lam1 * (
                float(np.sum(U_next * U_next)) + float(np.sum(V_next * V_next))
            )
            obj_history.append(smooth + reg)

        if options.verbose and (it + 1) % 50 == 0:
            msg = f"[BM-SEN] iter={it+1}, grad_norm={grad_norm:.3e}"
            if record:
                msg += f", obj={obj_history[-1]:.6f}"
            print(msg)

        # The step is applied even on the iteration that triggers convergence.
        U, V = U_next, V_next

        if grad_norm < options.tol:
            converged = True
            if options.verbose:
                print(f"[BM-SEN] converged at iter={it+1}, grad_norm={grad_norm:.3e}")
            break

    info: Dict[str, Any] = {
        "n_iter": n_iter,
        "converged": converged,
        "final_grad_norm": grad_norm,
    }
    if record:
        info["obj_history"] = np.array(obj_history)

    return U, V, info


# -------------------------------------------------------------
# 工具：L-BFGS（简单实现，适合中小维度）
# -------------------------------------------------------------


@dataclass
class LBFGSOptions:
    """Settings for ``lbfgs_minimize`` (L-BFGS with Armijo backtracking)."""

    # Cap on outer L-BFGS iterations.
    max_iter: int = 500
    # Stop once the gradient norm falls below this threshold.
    tol_grad: float = 1.0e-6
    # Number of (s, y) curvature pairs kept in memory.
    m_hist: int = 10
    # Armijo sufficient-decrease constant.
    c1: float = 0.0001
    # Factor by which the trial step shrinks after each failed Armijo test.
    ls_beta: float = 0.5
    # Maximum number of backtracking trials per iteration.
    ls_max_iter: int = 20
    # Print a message when the gradient tolerance is met.
    verbose: bool = False


def lbfgs_minimize(
    x0: Array,
    f_and_grad,
    options: Optional[LBFGSOptions] = None,
) -> Tuple[Array, Dict[str, Any]]:
    """Minimize a smooth function with a simple L-BFGS method.

    Search directions come from the standard two-loop recursion over the last
    ``options.m_hist`` curvature pairs. Step lengths come from Armijo
    backtracking starting at a unit step. The method is intended for small and
    medium problems such as the (U, V) inner solves in this module.

    Parameters
    ----------
    x0 : Array
        Starting point. It is copied and cast to float. The inner products
        below use ``np.dot``, so x0 is expected to be 1-D.
    f_and_grad : callable
        ``f_and_grad(x) -> (f, g)``, returning the objective value and its
        gradient (same shape as x).
    options : LBFGSOptions, optional
        Iteration cap, gradient tolerance, memory length and line-search
        settings.

    Returns
    -------
    x_hat : Array
        The final iterate.
    info : dict
        Contains:

        - ``'n_iter'``: number of L-BFGS steps taken (0 if x0 already
          satisfies the tolerance).
        - ``'final_grad_norm'``.
        - ``'f_final'``.
        - ``'converged'``: True when the final gradient norm is below
          ``options.tol_grad``.

    Raises
    ------
    ValueError
        If ``options.ls_max_iter < 1``, in which case no trial point could be
        evaluated.
    """
    if options is None:
        options = LBFGSOptions()
    if options.ls_max_iter < 1:
        # Without at least one trial point there is no x_new to accept.
        raise ValueError("ls_max_iter must be at least 1.")

    x = x0.astype(float)  # astype returns a copy, so x0 is left untouched
    f, g = f_and_grad(x)

    # Curvature pairs, oldest first:
    #   s_i = x_{i+1} - x_i,  y_i = g_{i+1} - g_i,  rho_i = 1 / (y_i . s_i).
    s_list = []
    y_list = []
    rho_list = []

    def two_loop(q: Array) -> Array:
        """Return H_k q via the standard L-BFGS two-loop recursion."""
        alpha = []
        for s, y, rho in reversed(list(zip(s_list, y_list, rho_list))):
            a = rho * np.dot(s, q)
            alpha.append(a)
            q = q - a * y
        # Initial inverse-Hessian guess gamma * I, with gamma = s.y / y.y taken
        # from the newest pair (1.0 when the history is empty).
        if s_list:
            gamma = np.dot(s_list[-1], y_list[-1]) / np.dot(y_list[-1], y_list[-1])
        else:
            gamma = 1.0
        r = gamma * q
        for s, y, rho, a in zip(s_list, y_list, rho_list, reversed(alpha)):
            b = rho * np.dot(y, r)
            r = r + s * (a - b)
        return r

    n_iter = 0  # number of accepted steps
    for k in range(options.max_iter):
        g_norm = float(np.linalg.norm(g))
        if g_norm < options.tol_grad:
            if options.verbose:
                print(f"[L-BFGS] converged at iter={n_iter}, ||g||={g_norm:.3e}")
            break

        # Search direction p_k = -H_k g_k (plain -g until history exists).
        p = -two_loop(g) if s_list else -g
        gTp = float(np.dot(g, p))
        if gTp >= 0:
            # Not a descent direction: fall back to steepest descent.
            p = -g
            gTp = -float(np.dot(g, g))

        # Armijo backtracking from a unit step. If no trial passes the
        # sufficient-decrease test, the last (smallest) trial is still accepted.
        step = 1.0
        for _ in range(options.ls_max_iter):
            x_new = x + step * p
            f_new, g_new = f_and_grad(x_new)
            if f_new <= f + options.c1 * step * gTp:
                break
            step *= options.ls_beta

        s = x_new - x
        y = g_new - g
        ys = float(np.dot(y, s))
        # Store only pairs with positive curvature. m_hist <= 0 disables the
        # history, which reduces the method to steepest descent with line search.
        if ys > 1e-12 and options.m_hist > 0:
            if len(s_list) >= options.m_hist:
                s_list.pop(0)
                y_list.pop(0)
                rho_list.pop(0)
            s_list.append(s)
            y_list.append(y)
            rho_list.append(1.0 / ys)

        x, f, g = x_new, f_new, g_new
        n_iter = k + 1

    final_grad_norm = float(np.linalg.norm(g))
    info: Dict[str, Any] = {
        "n_iter": n_iter,
        "final_grad_norm": final_grad_norm,
        "f_final": float(f),
        "converged": final_grad_norm < options.tol_grad,
    }
    return x, info


# -------------------------------------------------------------
# 工具：Power Method 求谱范数和主奇异向量
# -------------------------------------------------------------


def top_singular_triplet_power(
    G: Array,
    n_iter: int = 100,
    tol: float = 1e-6,
) -> Tuple[float, Array, Array]:
    """Estimate the leading singular triplet (sigma1, u1, v1) of G by power iteration.

    The iteration starts from a unit vector v drawn with a fixed seed (0), so
    results are reproducible. Each round performs

        u <- G v   / ||G v||
        v <- G^T u / ||G^T u||

    and stops after ``n_iter`` rounds, or once successive sigma estimates
    differ by less than ``tol * max(1, |sigma_old|)``.

    Parameters
    ----------
    G : Array
        Matrix of shape (m, n).
    n_iter : int
        Maximum number of power-iteration rounds.
    tol : float
        Convergence tolerance on the sigma estimate.

    Returns
    -------
    sigma : float
        ||G v|| for the final v; 0.0 on early exit when G v or G^T u vanishes.
    u : Array
        G v / sigma (left singular vector estimate). Zeros on early exit.
    v : Array
        Unit right singular vector estimate. Zeros on early exit.
    """
    m, n = G.shape
    rng = np.random.default_rng(0)
    v = rng.standard_normal(size=(n,))
    v /= np.linalg.norm(v)

    sigma_old = 0.0
    for _ in range(n_iter):
        u = G @ v
        u_norm = np.linalg.norm(u)
        if u_norm == 0:
            return 0.0, np.zeros(m), np.zeros(n)
        u /= u_norm

        v = G.T @ u
        v_norm = np.linalg.norm(v)
        if v_norm == 0:
            return 0.0, np.zeros(m), np.zeros(n)
        v /= v_norm

        # For unit u and v = G^T u / ||G^T u||, u^T G v = ||G^T u||. The
        # estimate is therefore v_norm itself, with no extra product with G.
        sigma = float(v_norm)
        if abs(sigma - sigma_old) < tol * max(1.0, abs(sigma_old)):
            break
        sigma_old = sigma

    # Final estimate sigma = ||G v||, which makes G v = sigma * u exact.
    u = G @ v
    sigma = float(np.linalg.norm(u))
    if sigma > 0:
        u /= sigma
    return sigma, u, v


# -------------------------------------------------------------
# Rank-adaptive BM-SEN for quadratic SEN:
#
#   h(B) = 0.5 ||B - B_true||^2 + (lam2/2)||B||^2
#   F(B) = h(B) + lam1 ||B||_*
#
# 在 (U,V) 空间优化：
#   F_r(U,V) = h(UV^T) + (lam1/2)(||U||^2 + ||V||^2)
#
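# Rank adaptation + spectral-norm certificate:
#   G = ∇h(B_hat) = (B_hat - B_true) + lam2 * B_hat
#   At a stationary (U, V), σ1(G) <= lam1 certifies that B_hat is globally
#   optimal; the code uses the relaxed test σ1(G) <= lam1 (1 + eps), i.e.
#   optimality up to the slack eps.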
# -------------------------------------------------------------


@dataclass
class RankAdaptiveOptions:
    """Settings for ``rank_adaptive_bm_sen_quadratic``."""

    # Largest rank the outer loop may grow to before giving up.
    max_rank: int = 10
    # Inner L-BFGS settings. default_factory gives every instance its own
    # LBFGSOptions. A shared instance as a plain default is rejected by
    # ``dataclass`` on Python 3.11+ (unhashable default) and would be aliased
    # across all option objects on older versions.
    lbfgs_options: LBFGSOptions = field(default_factory=LBFGSOptions)
    # Certificate slack eps: accept when sigma1(grad h) <= lam1 * (1 + eps).
    eps_cert: float = 1e-2
    # Scale of the column pair appended when the rank is increased.
    alpha_init: float = 1e-2
    # Print per-outer-iteration progress.
    verbose: bool = False


def rank_adaptive_bm_sen_quadratic(
    B_true: Array,
    lam1: float,
    lam2: float,
    options: Optional[RankAdaptiveOptions] = None,
) -> Tuple[Array, Dict[str, Any]]:
    """Rank-adaptive Burer–Monteiro SEN solver for a simple quadratic test problem.

    The problem is

        h(B) = 0.5 ||B - B_true||_F^2 + (lam2/2) ||B||_F^2
        F(B) = h(B) + lam1 ||B||_*

    Starting from rank r = 1, each outer iteration does the following:

    1. Minimize F_r(U, V) = h(U V^T) + (lam1/2)(||U||_F^2 + ||V||_F^2) over
       U (m x r) and V (n x r) with L-BFGS, warm-started from the previous
       factors.
    2. Compute G = grad h(B_hat) and its top singular triplet. Stop with
       ``converged=True`` if sigma1(G) <= lam1 * (1 + eps_cert).
    3. Otherwise, if r < max_rank, append one column to U and V along the
       leading singular pair of -G (a descent direction for F_r) and repeat.

    Parameters
    ----------
    B_true : Array
        Target matrix of shape (m, n).
    lam1, lam2 : float
        Nuclear-norm and ridge weights; both must be non-negative.
    options : RankAdaptiveOptions, optional
        Rank cap, inner L-BFGS settings, certificate slack, init scale and
        verbosity.

    Returns
    -------
    B_hat : Array
        U V^T from the last outer iteration.
    info : dict
        Contains:

        - ``'rank'``.
        - ``'certificate_sigma'``.
        - ``'n_outer'``.
        - ``'inner_info'``: info dict from the last L-BFGS run.
        - ``'converged'``: True if the certificate passed.
        - ``'h_value'``: h(B_hat).

    Raises
    ------
    ValueError
        If lam1 or lam2 is negative.
    """
    if options is None:
        options = RankAdaptiveOptions()

    m, n = B_true.shape
    lam1 = float(lam1)
    lam2 = float(lam2)
    if lam1 < 0 or lam2 < 0:
        raise ValueError("lam1 and lam2 must be non-negative.")

    def h_value(B: Array) -> float:
        R = B - B_true
        return 0.5 * float(np.sum(R * R)) + 0.5 * lam2 * float(np.sum(B * B))

    def grad_h(B: Array) -> Array:
        return (B - B_true) + lam2 * B

    r = 1
    rng = np.random.default_rng(0)
    U = 0.01 * rng.standard_normal(size=(m, r))
    V = 0.01 * rng.standard_normal(size=(n, r))

    outer_iter = 0
    while True:
        outer_iter += 1
        if options.verbose:
            print(f"\n[Rank-adaptive BM-SEN] outer iter {outer_iter}, rank = {r}")

        # Step 2: minimize F_r over theta = [vec(U); vec(V)] with L-BFGS.
        # The closures read the current r, so they are rebuilt every outer
        # iteration and only called before r changes.
        def unpack(theta: Array) -> Tuple[Array, Array]:
            return theta[: m * r].reshape(m, r), theta[m * r :].reshape(n, r)

        def f_and_grad_theta(theta: Array) -> Tuple[float, Array]:
            U_, V_ = unpack(theta)
            B_ = U_ @ V_.T
            reg = 0.5 * lam1 * (float(np.sum(U_ * U_)) + float(np.sum(V_ * V_)))
            G_ = grad_h(B_)  # (m, n)
            grad_theta = np.concatenate(
                [(G_ @ V_ + lam1 * U_).ravel(), (G_.T @ U_ + lam1 * V_).ravel()]
            )
            return h_value(B_) + reg, grad_theta

        theta0 = np.concatenate([U.ravel(), V.ravel()])
        theta_hat, info_inner = lbfgs_minimize(
            theta0, f_and_grad_theta, options=options.lbfgs_options
        )
        U, V = unpack(theta_hat)
        B_hat = U @ V.T

        # Step 3: spectral certificate on G = grad h(B_hat).
        G = grad_h(B_hat)
        sigma1, u1, v1 = top_singular_triplet_power(G)

        if options.verbose:
            print(
                f"[Rank-adaptive BM-SEN] rank={r}, "
                f"sigma1(G)={sigma1:.4f}, lam1={lam1:.4f}"
            )

        certified = sigma1 <= lam1 * (1.0 + options.eps_cert)
        if certified or r >= options.max_rank:
            if not certified and options.verbose:
                print("[Rank-adaptive BM-SEN] reached max_rank, stop.")
            info: Dict[str, Any] = {
                "rank": r,
                "certificate_sigma": sigma1,
                "n_outer": outer_iter,
                "inner_info": info_inner,
                "converged": certified,
                "h_value": h_value(B_hat),
            }
            return B_hat, info

        # Step 4: grow the rank by one.
        #
        # Appending the column pair (a*u, a*w) with unit u, w changes B by
        # a^2 u w^T. F_r therefore changes by about a^2 * (u^T G w + lam1).
        #
        # The power method returns G v1 = sigma1 * u1, so u1^T G v1 = +sigma1:
        # the pair (u1, v1) is an ascent direction. The pair (u1, -v1) gives
        # a^2 * (lam1 - sigma1) < 0, because the certificate just failed
        # (sigma1 > lam1 * (1 + eps_cert)).
        r += 1
        alpha = options.alpha_init
        U = np.column_stack([U, alpha * u1])
        V = np.column_stack([V, -alpha * v1])

