"""
Multi-Modal Error Signatures Report Generator

Generates six independent analysis figures, each saved to its own image file:
(a) Spectrum - Spectrum plot
(b) Time-domain error - Time-domain error
(c) Phase-domain error - Phase-domain error (polar coordinates)
(d) Code overflow - Code overflow distribution
(e) Error histogram by phase - Phase-domain error histogram
(f) Error histogram by code - Code-domain error histogram

Author: ADC Toolbox
Date: 2025-11-21
"""

import numpy as np
import matplotlib.pyplot as plt
import os


def plot_a_spectrum(data, fs=1e6, num_bits=12, fin=None, output_path='fig_a_spectrum.png'):
    """
    Figure (a): Spectrum

    Draws the output spectrum with analyze_spectrum, re-titles it for the
    multimodal report and saves it.

    Parameters:
        data: ADC output data (flattened to 1-D)
        fs: Sampling frequency (Hz). Currently unused: the spectrum is drawn
            with Fs=1.0 (normalized frequency), matching run_all_tests.py.
        num_bits: ADC bit width (currently unused; kept for a uniform signature)
        fin: Normalized input frequency (currently unused; analyze_spectrum is
            not given a frequency hint, so no detection is performed here)
        output_path: Output file path

    Returns:
        output_path
    """
    from .analyze_spectrum import analyze_spectrum

    data = np.asarray(data).flatten()

    # isPlot=1 makes analyze_spectrum draw the plot (presumably into the current
    # pyplot figure, which is re-titled and saved below). Its returned metrics
    # are not needed for this figure.
    analyze_spectrum(
        data, Fs=1.0, harmonic=0, label=1, OSR=1, isPlot=1
    )

    # Update title to match multimodal format
    plt.title('(a) Spectrum\nOutput Spectrum', fontsize=14, fontweight='bold', loc='left')

    # Save the figure generated by analyze_spectrum
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"  [OK] Saved: {output_path}")

    return output_path


def plot_b_time_domain_error(data, fin, num_bits=12, output_path='fig_b_time_error.png'):
    """
    Figure (b): Time-domain error

    Plots the raw samples and the fitted signal (left axis) together with the
    ``dep`` and ``indep`` components returned by tomDecomp (right axis) over
    roughly the first 1.5 input periods.

    Parameters:
        data: ADC output data (flattened to 1-D)
        fin: Normalized input frequency f_in/f_sample
        num_bits: ADC bit width (currently unused; kept for a uniform signature)
        output_path: Output file path

    Returns:
        output_path
    """
    from adctoolbox.aout.tomDecomp import tomDecomp

    data = np.asarray(data).flatten()
    N = len(data)

    # Thompson decomposition (order=50, consistent with run_all_tests.py)
    signal, _error, indep, dep, _phi = tomDecomp(data, fin, order=50, disp=0)

    # Show ~1.5 input periods (at least 100 samples, never more than N).
    # A non-positive fin (e.g. a DC-bin fallback estimate) has no period, so
    # show every sample instead of dividing by zero.
    xlim = min(max(int(1.5 / fin), 100), N) if fin > 0 else N

    # Create dual Y-axis plot (figsize=(12, 6), consistent with run_all_tests.py)
    fig, ax1 = plt.subplots(figsize=(12, 6))
    ax2 = ax1.twinx()

    # Left Y-axis: signal
    ax1.plot(data[:xlim], 'kx', markersize=3, alpha=0.5, label='data')
    ax1.plot(signal[:xlim], '-', color='gray', linewidth=1.5, label='signal')
    ax1.set_xlim([0, xlim])
    ax1.set_ylabel('Signal', fontsize=12)
    ax1.tick_params(axis='y', labelcolor='k')

    # Right Y-axis: error components
    ax2.plot(dep[:xlim], 'r-', label='dep', linewidth=1.5)
    ax2.plot(indep[:xlim], 'b-', label='indep', linewidth=1)
    ax2.set_ylabel('Error', fontsize=12)
    ax2.tick_params(axis='y', labelcolor='r')

    ax1.set_xlabel('Samples', fontsize=12)
    ax1.set_title('(b) Time-domain error', fontsize=14, fontweight='bold', loc='left')

    # Merge both axes' entries into a single legend. It is attached to ax2
    # because the twin axes are drawn above ax1 and would otherwise cover it.
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles1 + handles2, labels1 + labels2, loc='upper left', fontsize=10)
    ax1.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")

    return output_path


def plot_c_phase_domain_error(data, fin, num_bits=12, output_path='fig_c_phase_error.png'):
    """
    Figure (c): Phase-domain error (polar coordinates)

    Scatters |error| from tomDecomp against the phase of each sample
    (0..2*pi rad) on a polar axis whose radius is log-scaled. Exact-zero
    errors are lifted to the smallest non-zero magnitude so they stay
    visible on the log axis.

    Parameters:
        data: ADC output data (flattened to 1-D)
        fin: Normalized input frequency f_in/f_sample
        num_bits: ADC bit width (currently unused; kept for a uniform signature)
        output_path: Output file path

    Returns:
        output_path
    """
    from adctoolbox.aout.tomDecomp import tomDecomp

    samples = np.asarray(data).flatten()

    # Only the error and the phase offset are needed (order=50, consistent with run_all_tests.py)
    _, err, _, _, phi0 = tomDecomp(samples, fin, order=50, disp=0)

    # Phase of every sample, wrapped into [0, 2*pi)
    sample_idx = np.arange(len(samples))
    theta = np.mod(phi0 + 2 * np.pi * fin * sample_idx, 2 * np.pi)

    # Log radius: zero magnitudes would be log(0), so substitute the smallest
    # positive magnitude (or 1e-10 when every magnitude is zero).
    magnitude = np.abs(err)
    positive = magnitude > 0
    has_positive = bool(positive.any())
    floor_value = magnitude[positive].min() if has_positive else 1e-10
    radius = np.where(positive, magnitude, floor_value)

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(projection='polar'))
    ax.scatter(theta, radius, c='blue', s=0.5, alpha=0.3)
    ax.set_title('(c) Phase-domain error\nSpectrum Phase (Log Scale)',
                 fontsize=14, fontweight='bold', pad=20)
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(1)
    ax.set_yscale('log')

    # Radial limits bracket the data by a factor of two on each side
    if has_positive:
        ax.set_ylim([floor_value * 0.5, magnitude.max() * 2])

    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")

    return output_path


def plot_d_code_overflow(data, num_bits=12, output_path='fig_d_code_overflow.png'):
    """
    Figure (d): Code overflow

    Rescales the samples onto the full num_bits code range using their own
    min/max. For every bit it then plots the percentage of samples in which
    that bit is set, with the MSB on the left.

    Parameters:
        data: ADC output data (flattened to 1-D; must be non-empty)
        num_bits: ADC bit width
        output_path: Output file path

    Returns:
        output_path
    """
    data = np.asarray(data).flatten()
    N = len(data)
    max_code = 2**num_bits - 1

    # Normalize to [0, 1] and convert to integer codes. A constant record has
    # zero span: dividing by it would produce NaN and an undefined NaN->int
    # cast, so every sample is mapped to code 0 instead.
    data_min = np.min(data)
    span = np.max(data) - data_min
    if span == 0:
        normalized = np.zeros(N)
    else:
        normalized = (data - data_min) / span
    # int64 keeps wide codes (num_bits > 31) from overflowing a 32-bit default int
    codes = np.clip(normalized * max_code, 0, max_code).astype(np.int64)

    # Percentage of samples with each bit set (index 0 = LSB)
    bit_distribution = np.array(
        [np.count_nonzero((codes >> bit_idx) & 1) for bit_idx in range(num_bits)]
    ) / N * 100

    # Reverse order (MSB on left)
    bit_labels = [f'{num_bits - 1 - i}' for i in range(num_bits)]
    x_pos = np.arange(num_bits)
    reversed_dist = bit_distribution[::-1]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(x_pos, reversed_dist, color='royalblue', edgecolor='navy', linewidth=1.5)

    # Percentage label just above each bar
    for bar, pct in zip(bars, reversed_dist):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{pct:.1f}%', ha='center', va='bottom', fontsize=8)

    ax.set_xlabel('bit', fontsize=12)
    ax.set_ylabel('Relative Distribution', fontsize=12)
    ax.set_title('(d) Code overflow', fontsize=14, fontweight='bold', loc='left')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(bit_labels)
    ax.set_ylim([0, 105])
    ax.grid(axis='y', alpha=0.3)

    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")

    return output_path


def plot_e_error_hist_by_phase(data, fin, num_bits=12, output_path='fig_e_error_hist_phase.png'):
    """
    Figure (e): Error histogram by phase

    Top panel: the tomDecomp error of every sample plotted against its phase
    (degrees), overlaid with the binned mean error returned by errHistPhase.
    Bottom panel: the binned RMS error as a bar chart.

    Parameters:
        data: ADC output data (flattened to 1-D)
        fin: Normalized input frequency f_in/f_sample
        num_bits: ADC bit width (currently unused; kept for a uniform signature)
        output_path: Output file path

    Returns:
        output_path
    """
    from adctoolbox.aout.tomDecomp import tomDecomp
    from adctoolbox.aout.errHistSine import errHistPhase

    data = np.asarray(data).flatten()
    N = len(data)

    # Thompson decomposition to get error and phase (order=50, consistent with run_all_tests.py)
    _signal, error, _indep, _dep, phi = tomDecomp(data, fin, order=50, disp=0)

    # errHistPhase is fed data scaled by a fixed 2^12 (convention kept from
    # run_all_tests.py). Cast to float first: scaling an integer capture
    # (e.g. int16 codes) in its own dtype would silently overflow.
    scale = 2**12
    emean_phase, erms_phase, phase_bins, _ = errHistPhase(
        data.astype(float) * scale, bin_count=1000, fin=fin, disp=0
    )
    # Bring the error statistics back to input-data units so they share the
    # y-axis with the tomDecomp error scatter.
    # NOTE(review): assumes errHistPhase's error outputs scale linearly with its
    # input (as a sine-fit residual does), and that its sign convention matches
    # tomDecomp's error. Confirm against errHistSine.
    emean_phase = np.asarray(emean_phase) / scale
    erms_phase = np.asarray(erms_phase) / scale
    phase_bins = np.asarray(phase_bins)

    # Phase of every sample in degrees (phi converted from radians), wrapped to [0, 360)
    t = np.arange(N)
    phi_list = (phi / np.pi * 180 + t * fin * 360) % 360

    fig = plt.figure(figsize=(10, 8))

    # Top part: error scatter and mean
    ax1 = fig.add_subplot(2, 1, 1)
    ax1.plot(phi_list, error, 'r.', markersize=0.5, alpha=0.3, label='error')
    ax1.plot(phase_bins, emean_phase, 'b-', linewidth=2, label='mean')
    ax1.set_ylabel('error', fontsize=12)
    ax1.set_title('(e) Error histogram by phase', fontsize=14, fontweight='bold', loc='left')
    ax1.set_xlim([0, 360])
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Bottom part: RMS bar chart. Bar width follows the actual bin spacing
    # (same rule as figure (f)) instead of assuming 1-degree bins.
    ax2 = fig.add_subplot(2, 1, 2)
    bin_width = (phase_bins[1] - phase_bins[0]) if len(phase_bins) > 1 else 1
    ax2.bar(phase_bins, erms_phase, width=bin_width*0.8, color='skyblue', edgecolor='navy', linewidth=0.5)
    ax2.set_xlabel('phase(deg)', fontsize=12)
    ax2.set_ylabel('RMS error', fontsize=12)
    ax2.set_xlim([0, 360])
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")

    return output_path


def plot_f_error_hist_by_code(data, fin, num_bits=12, output_path='fig_f_error_hist_code.png'):
    """
    Figure (f): Error histogram by code

    Top panel: the tomDecomp error of every sample plotted against its output
    value, overlaid with the binned mean error returned by errHistCode.
    Bottom panel: the binned RMS error as a bar chart.

    Parameters:
        data: ADC output data (flattened to 1-D)
        fin: Normalized input frequency f_in/f_sample
        num_bits: ADC bit width (currently unused; kept for a uniform signature)
        output_path: Output file path

    Returns:
        output_path
    """
    from adctoolbox.aout.tomDecomp import tomDecomp
    from adctoolbox.aout.errHistSine import errHistCode

    data = np.asarray(data).flatten()

    # Thompson decomposition to get error (order=50, consistent with run_all_tests.py)
    _signal, error, _indep, _dep, _phi = tomDecomp(data, fin, order=50, disp=0)

    # errHistCode is fed data scaled by a fixed 2^12 (convention kept from
    # run_all_tests.py). Cast to float first: scaling an integer capture
    # (e.g. int16 codes) in its own dtype would silently overflow.
    scale = 2**12
    emean_code, erms_code, code_bins, _ = errHistCode(
        data.astype(float) * scale, bin_count=1000, fin=fin, disp=0
    )
    # Convert everything back to input-data units. Otherwise the bins sit 4096x
    # to the right of the [min(data), max(data)] x-range used below, and the
    # mean/RMS lines do not share units with the tomDecomp error.
    # NOTE(review): assumes errHistCode returns bins in its input units and
    # errors that scale linearly with its input. Confirm against errHistSine.
    emean_code = np.asarray(emean_code) / scale
    erms_code = np.asarray(erms_code) / scale
    code_bins = np.asarray(code_bins) / scale

    data_lo, data_hi = np.min(data), np.max(data)

    fig = plt.figure(figsize=(10, 8))

    # Top part: error scatter and mean
    ax1 = fig.add_subplot(2, 1, 1)
    ax1.plot(data, error, 'r.', markersize=0.5, alpha=0.3, label='error')
    ax1.plot(code_bins, emean_code, 'b-', linewidth=2, label='mean')
    ax1.set_ylabel('error', fontsize=12)
    ax1.set_title('(f) Error histogram by code', fontsize=14, fontweight='bold', loc='left')
    ax1.set_xlim([data_lo, data_hi])
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Bottom part: RMS bar chart (bar width follows the bin spacing)
    ax2 = fig.add_subplot(2, 1, 2)
    bin_width = (code_bins[1] - code_bins[0]) if len(code_bins) > 1 else 1
    ax2.bar(code_bins, erms_code, width=bin_width*0.8, color='skyblue', edgecolor='navy', linewidth=0.5)
    ax2.set_xlabel('code', fontsize=12)
    ax2.set_ylabel('RMS error', fontsize=12)
    ax2.set_xlim([data_lo, data_hi])
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")

    return output_path


def _auto_detect_fin(samples, n_samples):
    """
    Estimate and print the normalized input frequency of ``samples``.

    Uses adctoolbox's sine_fit + find_bin when they can be imported
    (consistent with run_all_tests.py). Otherwise it falls back to the
    largest non-DC FFT bin in the first half of the spectrum.
    """
    try:
        from adctoolbox.common.sineFit import sine_fit
        from adctoolbox.common.findBin import find_bin
        _, freq_est, _, _, _ = sine_fit(samples)
        estimate = find_bin(1, freq_est, n_samples) / n_samples
        print(f"  Detected frequency: {estimate:.10f}")
    except ImportError:
        magnitude = np.abs(np.fft.fft(samples))
        magnitude[0] = 0  # ignore the DC bin
        estimate = np.argmax(magnitude[:n_samples // 2]) / n_samples
        print(f"  Estimated frequency: {estimate:.10f}")
    return estimate


def generate_multimodal_report(data, fs=1e6, num_bits=12, fin=None, output_dir='.'):
    """
    Generate the Multi-Modal Error Signatures Report as six separate figures.

    Parameters:
        data: ADC output data (flattened to a 1-D numpy array)
        fs: Sampling frequency (Hz); printed and forwarded to figure (a)
        num_bits: ADC bit width; forwarded to every figure generator
        fin: Normalized input frequency f_in/f_sample; estimated from the
            data when None or 0
        output_dir: Directory that receives the figures (created if missing)

    Returns:
        output_paths: List of the six figure paths, in order (a) through (f)
    """
    samples = np.asarray(data).flatten()
    n_samples = len(samples)

    print("Generating Multi-Modal Error Signatures Report (6 individual figures)...")
    print(f"  Data points: {n_samples}")
    print(f"  Sample rate: {fs/1e6:.2f} MHz")
    print(f"  ADC bits: {num_bits}")

    if fin is None or fin == 0:
        fin = _auto_detect_fin(samples, n_samples)

    os.makedirs(output_dir, exist_ok=True)

    # (progress banner, output file name, renderer taking the output path)
    jobs = (
        ("\n[1/6] Generating (a) Spectrum...", 'fig_a_spectrum.png',
         lambda path: plot_a_spectrum(samples, fs, num_bits, fin, path)),
        ("\n[2/6] Generating (b) Time-domain error...", 'fig_b_time_error.png',
         lambda path: plot_b_time_domain_error(samples, fin, num_bits, path)),
        ("\n[3/6] Generating (c) Phase-domain error...", 'fig_c_phase_error.png',
         lambda path: plot_c_phase_domain_error(samples, fin, num_bits, path)),
        ("\n[4/6] Generating (d) Code overflow...", 'fig_d_code_overflow.png',
         lambda path: plot_d_code_overflow(samples, num_bits, path)),
        ("\n[5/6] Generating (e) Error histogram by phase...", 'fig_e_error_hist_phase.png',
         lambda path: plot_e_error_hist_by_phase(samples, fin, num_bits, path)),
        ("\n[6/6] Generating (f) Error histogram by code...", 'fig_f_error_hist_code.png',
         lambda path: plot_f_error_hist_by_code(samples, fin, num_bits, path)),
    )

    output_paths = []
    for banner, filename, render in jobs:
        print(banner)
        figure_path = os.path.join(output_dir, filename)
        render(figure_path)
        output_paths.append(figure_path)

    print("\n[OK] All 6 figures generated successfully!")
    print(f"Output directory: {output_dir}")

    return output_paths


if __name__ == "__main__":
    print("=" * 70)
    print("Multi-Modal Error Signatures Report Generator - Test")
    print("=" * 70)

    # Test-signal parameters. fin_hz / fs * N = 116, an integer number of
    # input cycles, so the record is coherently sampled.
    N = 4096
    fs = 1e6
    fin_hz = 28320.3125
    re_fin = fin_hz / fs
    num_bits = 12

    t = np.arange(N) / fs

    # Ideal sine around mid-scale with amplitude 2^(num_bits-1) - 100 codes
    signal_ideal = np.sin(2 * np.pi * fin_hz * t) * (2**(num_bits-1) - 100) + 2**(num_bits-1)

    # Add 3rd harmonic distortion (amplitude 50 codes)
    phase = (2 * np.pi * fin_hz * t) % (2 * np.pi)
    harmonic_3rd = 50 * np.sin(3 * phase)

    # Add Gaussian noise (sigma = 10 codes). A fixed seed keeps the generated
    # figures reproducible from run to run.
    rng = np.random.default_rng(0)
    noise = 10 * rng.standard_normal(N)

    # Synthesize ADC output, clipped to the valid code range
    adc_output = np.clip(signal_ideal + harmonic_3rd + noise, 0, 2**num_bits - 1)

    print("\nTest parameters:")
    print(f"  Sample count: {N}")
    print(f"  Sampling frequency: {fs/1e6:.2f} MHz")
    print(f"  Input frequency: {fin_hz/1e3:.2f} kHz")
    print(f"  ADC bits: {num_bits}")

    # Generate report next to this script's parent directory. abspath makes this
    # independent of the current working directory when __file__ is relative.
    output_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "output_data", "test_multimodal"
    )

    generate_multimodal_report(
        adc_output,
        fs=fs,
        num_bits=num_bits,
        fin=re_fin,
        output_dir=output_dir
    )

    print("\n[OK] Test completed!")
    print("=" * 70)
