"""
Module that implements 4 related concepts:
  - TargetDuration:    how long an algorithm (or part of an algorithm) should run, expressed either
                         as a time budget (seconds / minutes / hours) or as a number of iterations.
  - ProgressTracker:   created by TargetDuration.track(); it starts tracking that duration as soon
                         as it is created.
  - Progress:          a snapshot of the progress made so far, as reported by a ProgressTracker;
                         it can also be used to update a tqdm progress bar.
  - Elapsed:           the observed duration, in both elapsed seconds and completed iterations, as
                         returned by ProgressTracker.elapsed().
"""

from __future__ import annotations

import time
from abc import ABC, abstractmethod
from dataclasses import dataclass

from tqdm import tqdm


# =================================================================================================
#  TargetDuration
# =================================================================================================
class TargetDuration(ABC):
    """Abstract description of how long an algorithm (or part of one) should run.

    Concrete targets are created through the factory classmethods below; call track() on a target
    to obtain a ProgressTracker that measures progress towards it.
    """

    @abstractmethod
    def track(self) -> ProgressTracker:
        """Return a new ProgressTracker for this target (tracking starts when it is constructed)."""
        raise NotImplementedError()

    # -------------------------------------------------------------------------
    #  Factory methods
    # -------------------------------------------------------------------------
    @classmethod
    def iterations(cls, n_iters: int) -> TargetDuration:
        """Target a fixed number of iterations."""
        return _TargetIterationCount(n_iters)

    @classmethod
    def seconds(cls, t_target_sec: float) -> TargetDuration:
        """Target a time budget given in seconds."""
        return _TargetTimeDuration(t_target_sec)

    @classmethod
    def minutes(cls, t_target_min: float) -> TargetDuration:
        """Target a time budget given in minutes (stored internally as seconds)."""
        t_target_sec = t_target_min * 60.0
        return _TargetTimeDuration(t_target_sec)

    @classmethod
    def hours(cls, t_target_hours: float) -> TargetDuration:
        """Target a time budget given in hours (stored internally as seconds)."""
        t_target_sec = t_target_hours * 3600.0
        return _TargetTimeDuration(t_target_sec)


class _TargetTimeDuration(TargetDuration):
    """TargetDuration expressed as a time budget in seconds."""

    def __init__(self, t_target_sec: float):
        # Written as a range check instead of `t_target_sec <= 0` so that NaN (for which every
        # comparison is False) and +inf are rejected here.  Both would otherwise only fail later,
        # inside _TimeTracker (int(inf) -> OverflowError, int(nan) -> ValueError).
        if not (0 < t_target_sec < float("inf")):
            raise ValueError("t_target_sec must be finite and > 0")
        self._t_target_sec = t_target_sec

    def __str__(self):
        return repr(self)

    def __repr__(self):
        # Fewer decimals for larger durations, so the printed value stays compact.
        t = self._t_target_sec
        if t <= 1.0:
            return f"TargetDuration({t:.3f} seconds)"
        elif t < 10.0:
            return f"TargetDuration({t:.2f} seconds)"
        elif t < 100.0:
            return f"TargetDuration({t:.1f} seconds)"
        else:
            return f"TargetDuration({int(round(t)):_} seconds)"

    def track(self) -> ProgressTracker:
        """Start a time-based tracker for this budget."""
        return _TimeTracker(self._t_target_sec)


class _TargetIterationCount(TargetDuration):
    """TargetDuration expressed as a fixed number of iterations."""

    def __init__(self, n_iters: int):
        if n_iters <= 0:
            raise ValueError("n_iters must be > 0")
        self._n_iters = n_iters

    def track(self) -> ProgressTracker:
        """Start an iteration-count tracker for this target."""
        return _IterationTracker(self._n_iters)

    def __repr__(self):
        return f"TargetDuration({self._n_iters:_} iterations)"

    # str() renders exactly the same text as repr()
    __str__ = __repr__


# --- shorthand factory methods ---------------------------
# Module-level aliases, so callers can write e.g. `seconds(5)` instead of `TargetDuration.seconds(5)`.
iterations, seconds, minutes, hours = (
    TargetDuration.iterations,
    TargetDuration.seconds,
    TargetDuration.minutes,
    TargetDuration.hours,
)


# =================================================================================================
#  ProgressTracker
# =================================================================================================
class ProgressTracker(ABC):
    """Tracks algorithm progress in iterations and elapsed time; tracking starts upon construction.

    Subclasses implement get_progress() to express this progress relative to their own target.
    """

    def __init__(self):
        self._t_start = time.perf_counter()
        self._iter_count = 0

    def _t_elapsed_sec(self) -> float:
        """Seconds elapsed since this tracker was constructed."""
        return time.perf_counter() - self._t_start

    def report_iterations_done(self, n: int = 1):
        """Add `n` completed iterations (default: 1) to the running iteration count.

        Raises:
            ValueError: if n is negative; the iteration count is only allowed to grow.
        """
        if n < 0:
            raise ValueError(f"n must be >= 0, got {n}")
        self._iter_count += n

    def iter_count(self) -> int:
        """Total number of iterations reported so far."""
        return self._iter_count

    def elapsed(self) -> Elapsed:
        """Snapshot of elapsed time and iterations reported so far."""
        return Elapsed(
            t_elapsed_sec=self._t_elapsed_sec(),
            n_iterations=self._iter_count,
        )

    def iters_per_second(self) -> float:
        """Average iteration rate since construction; 0.0 until time has passed and iterations were reported."""
        t_elapsed = self._t_elapsed_sec()
        if t_elapsed > 0.0 and self._iter_count > 0:
            return self._iter_count / t_elapsed
        return 0.0

    @abstractmethod
    def get_progress(self) -> Progress:
        raise NotImplementedError()


class _TimeTracker(ProgressTracker):
    """ProgressTracker whose target is a time budget of `max_seconds` seconds."""

    def __init__(self, max_seconds: float):
        super().__init__()
        self._max_seconds = max_seconds
        # the tqdm bar gets one step per whole second of budget, with at least one step
        self._n_total = max(1, int(max_seconds))

    def get_progress(self) -> Progress:
        rate = self.iters_per_second()
        t_elapsed = time.perf_counter() - self._t_start
        budget = self._max_seconds

        done = t_elapsed >= budget
        fraction = 1.0 if done else t_elapsed / budget
        # while unfinished, always report at least one remaining iteration
        n_remaining = 0 if done else max(1, int((budget - t_elapsed) * rate))

        return Progress(
            tqdm_n_total=self._n_total,
            fraction=fraction,
            iter_count=self._iter_count,
            est_n_iters_remaining=n_remaining,
            est_iters_per_second=rate,
        )


class _IterationTracker(ProgressTracker):
    """ProgressTracker whose target is a fixed number of iterations."""

    def __init__(self, max_iters: int):
        super().__init__()
        self._max_iters = max_iters

    def get_progress(self) -> Progress:
        done, target = self._iter_count, self._max_iters
        fraction = 1.0 if done >= target else done / target
        return Progress(
            tqdm_n_total=target,
            fraction=fraction,
            iter_count=done,
            est_n_iters_remaining=max(0, target - done),
            est_iters_per_second=self.iters_per_second(),
        )


# =================================================================================================
#  Progress
# =================================================================================================
@dataclass(frozen=True, slots=True)
class Progress:
    """Representation of progress made so far towards a target duration, including various meta-data."""

    # --- tqdm ---
    tqdm_n_total: int  # total number of steps for tqdm progress bar

    # --- other---
    fraction: float  # fractional progress in [0,1]
    iter_count: int  # total number of iterations reported done
    est_n_iters_remaining: int  # estimated number of iterations remaining
    est_iters_per_second: float  # estimated number of iterations executed so far per second

    @property
    def tqdm_n_current(self) -> int:
        if self.fraction >= 1.0:
            return self.tqdm_n_total
        else:
            # report fractional progress, but limited to tqdm_n_total-1 in case we're not finished yet.
            return max(0, min(self.tqdm_n_total - 1, round(self.fraction * self.tqdm_n_total)))

    @property
    def est_progress_fraction_per_iter(self) -> float:
        """Estimated progress fraction increase per executed iteration."""
        est_total_iters = self.iter_count + self.est_n_iters_remaining
        # normally either iter_count or est_n_iters_remaining is >=1; just to be sure, we take max(1, ...)
        return 1.0 / max(1, est_total_iters)

    @property
    def is_finished(self) -> bool:
        return self.fraction >= 1.0

    def update_tqdm(self, pbar: tqdm):
        """Updates a tqdm progress bar to reflect current progress"""
        pbar.n = self.tqdm_n_current
        pbar.total = self.tqdm_n_total
        pbar.refresh()


# =================================================================================================
#  Elapsed
# =================================================================================================
@dataclass(frozen=True, slots=True)
class Elapsed:
    # --- data fields -----------------
    t_elapsed_sec: float
    n_iterations: int

    # --- equal -----------------------
    def __eq__(self, other) -> bool:
        # equal if...
        #   n_iterations is exactly equal
        #   t_elapsed_sec is equal within 1e-10 sec, which is < 1 clock cycle on typical modern hardware
        return (
            isinstance(other, Elapsed)
            and (self.n_iterations == other.n_iterations)
            and abs(self.t_elapsed_sec - other.t_elapsed_sec) < 1e-10
        )

    # --- math ------------------------
    def __add__(self, other: Elapsed) -> Elapsed:
        if other == 0:
            return self  # helps ensure sum() works correctly
        elif not isinstance(other, Elapsed):
            return NotImplemented
        return Elapsed(
            t_elapsed_sec=self.t_elapsed_sec + other.t_elapsed_sec,
            n_iterations=self.n_iterations + other.n_iterations,
        )

    def __radd__(self, other: Elapsed | int) -> Elapsed:
        return self + other


# Public API of this module: the names exported by `from <module> import *`.
# Must be spelled `__all__` (lowercase) -- Python ignores any other spelling.
__all__ = [
    "TargetDuration",
    "ProgressTracker",
    "Progress",
    "Elapsed",
    "iterations",
    "seconds",
    "minutes",
    "hours",
]

# Backward-compatible alias for code that referenced the previous (misspelled) name.
__ALL__ = __all__
