"""
Implement tests for the solver engine.

In this case, having the following function:

```
    f(x) = e^{-x^2}; [-1, 1]
```
We are going to discretize the function into a set of variables, such that:

```
f(Var) = sum_i^N Var_i, where Var_i = e^{-x_i^2} and N = (1 - (-1)) / dx
```

For the objective function (OB), we use the surface area of the solid of
revolution produced by the previous function. In this case, OB is:

```
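    OB(Var) = 2*pi * [sum_i^N Var_i * sqrt(1 + (dVar_i / dx)^2)]
```

Since the sum is not multiplied by dx, OB approximates the surface area
divided by dx; this constant factor does not change which Var minimizes it.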
"""
from functools import partial
import math
import numpy as np
# Local imports
from pymath_compute.solvers.opt_solver import OptSolver
from pymath_compute.methods import OptMethods
from pymath_compute import Variable


def f(x: float) -> float:
    """Evaluate the profile curve f(x) = e^{-x^2}.

    Its values at the grid points are used as the initial values of the
    optimization variables.

    Args:
        x: Point at which the curve is evaluated.

    Returns:
        ``np.exp(-(x ** 2))``.
    """
    squared = x ** 2
    return np.exp(-squared)


def approx_derivate(
    variables: list[float],
    index: int,
    dx: float
) -> float:
    """Approximate the derivative of the discretized function at ``index``.

    Interior points use the central difference
    ``(variables[index + 1] - variables[index - 1]) / (2 * dx)``.

    The end points lack a neighbour on one side. A "ghost" neighbour is
    built from the analytic slope of f(x) = e^{-x^2} at the ends of
    [-1, 1], namely f'(-1) = 2/e and f'(1) = -2/e. The central difference
    then evaluates to exactly that slope.

    Args:
        variables: Function values sampled on a uniform grid of step ``dx``.
        index: Grid position at which the derivative is evaluated.
        dx: Grid spacing.

    Returns:
        The approximated derivative at ``index``.
    """
    if index == len(variables) - 1:
        # Right end (x = 1): ghost "next" point gives slope -2/e
        next_var = variables[index - 1] - (4*dx)/math.e
        prev_var = variables[index - 1]
    elif index == 0:
        # Left end (x = -1): ghost "previous" point gives slope +2/e.
        # The ghost must lie BELOW the neighbour because f is increasing
        # at x = -1.
        next_var = variables[index + 1]
        prev_var = variables[index + 1] - (4*dx)/math.e
    else:
        next_var = variables[index + 1]
        prev_var = variables[index - 1]
    return (next_var - prev_var) / (2 * dx)


def objective_function(
    variables: dict[str, float],
    dx: float
) -> float:
    """Cost function based on the revolution surface of the discretized f(x).

    Computes ``2 * pi * sum_i v_i * sqrt(1 + d_i ** 2)``. Here ``v_i`` are
    the values of ``variables`` in iteration order, and ``d_i`` is the
    derivative estimate from ``approx_derivate`` at position ``i``.

    NOTE(review): the sum is not multiplied by ``dx``, so the result is
    approximately the surface area divided by ``dx``.

    Args:
        variables: Mapping of variable name to current value. Iteration
            order is treated as grid order.
        dx: Grid spacing passed through to ``approx_derivate``.

    Returns:
        The cost value.
    """
    values = list(variables.values())
    total = 0
    for position, value in enumerate(values):
        slope = approx_derivate(values, position, dx)
        total += value * np.sqrt(1 + slope ** 2)
    return 2 * np.pi * total


if __name__ == "__main__":

    # ========================================== #
    #           VARIABLES DEFINITION             #
    # ========================================== #

    # Define the array of variables
    f_variables: list[Variable] = []
    # Sample f(x) on [X0, XF] with a step of X_STEP; each sample becomes
    # one variable Var_i.
    X0 = -1
    XF = 1

    X_STEP = 0.001
    # (XF - X0) / X_STEP intervals need one extra point so that the
    # linspace spacing is exactly X_STEP, the dx used by the derivative.
    # The parentheses are required: `XF - X0 / X_STEP` would evaluate
    # XF - (X0 / X_STEP). round() absorbs float error in the division.
    N_VARS = round((XF - X0) / X_STEP) + 1
    X_LINSPACE = np.linspace(X0, XF, N_VARS)
    for i, x_value in enumerate(X_LINSPACE):
        # We do not define lb and ub so we can get the
        # default boundaries, being lb=-inf and ub=inf
        variable = Variable(name=f"Var_{i+1}")
        # Then, based on the function f(x) = e^{-x^2},
        # define the initial value of this variable
        variable.value = f(x_value)
        # Append this variable
        f_variables.append(variable)

    # ========================================== #
    #             SOLVER DEFINITION              #
    # ========================================== #
    # Instance the solver
    solver = OptSolver()
    solver.set_variables(variables=f_variables)
    solver.set_objective_function(
        partial(objective_function, dx=X_STEP)  # type: ignore
    )
    # Set the parameters
    solver.set_solver_config({
        "solver_time": 30,
        "solver_method": OptMethods.GRADIENT_DESCENT
    })
    # Run the solver!
    solver.solve()
    # Evaluate the status
    print(f"The status is {solver.status}")
    # Build an expression from the sum of variables (used for plotting)
    expr = f_variables[0].to_expression()
    for var in f_variables[1:]:
        expr = expr + var
    # Evaluate the objective on the variables' current values.
    # NOTE(review): expr.terms presumably holds per-term coefficients of the
    # expression rather than the variable values, so it is not used here.
    final_area = objective_function(
        {var.name: var.value for var in f_variables}, X_STEP
    )
    # Plot the solution
    expr.plot(
        title="$f(x)$ revolution solid values with " +
        f"A={final_area}",
        xlabel="$Var_i$",
        ylabel="$Var$ value",
        figsize=(10, 4),
        # store_as_pdf=True
    )
