"""
Module that implements 4 related concepts:
  - TargetDuration:    how long an algorithm or part of an algorithm should run
  - ProgressTracker:   generated by a TargetDuration when we ask to start tracking this duration
  - Progress:          a snapshot of progress made so far, as reported by a ProgressTracker, which
                         can also be used to update a tqdm progress bar.
  - Elapsed:           represents observed duration in terms of absolute time & iterations and can be
                         received from a ProgressTracker object.
"""

from __future__ import annotations

import math
import numbers
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass

from tqdm import tqdm


# =================================================================================================
#  TargetDuration
# =================================================================================================
class TargetDuration(ABC):
    """
    Abstract specification of how long an algorithm (or a part of it) should run.

    Concrete targets are obtained through the factory classmethods
    (`seconds`, `minutes`, `hours`, `iterations`); calling `track()` on a
    target starts measuring progress against it.
    """

    @abstractmethod
    def track(self) -> ProgressTracker:
        """Start tracking this target and return the corresponding ProgressTracker."""
        raise NotImplementedError()

    # -------------------------------------------------------------------------
    #  Factory methods
    # -------------------------------------------------------------------------
    @classmethod
    def seconds(cls, t_target_sec: float) -> TargetDuration:
        """Wall-clock target of `t_target_sec` seconds."""
        return _TargetTimeDuration(t_target_sec)

    @classmethod
    def minutes(cls, t_target_min: float) -> TargetDuration:
        """Wall-clock target of `t_target_min` minutes (converted to seconds)."""
        t_target_sec = t_target_min * 60.0
        return _TargetTimeDuration(t_target_sec)

    @classmethod
    def hours(cls, t_target_hours: float) -> TargetDuration:
        """Wall-clock target of `t_target_hours` hours (converted to seconds)."""
        t_target_sec = t_target_hours * 3600.0
        return _TargetTimeDuration(t_target_sec)

    @classmethod
    def iterations(cls, n_iters: int) -> TargetDuration:
        """Target of a fixed number of iterations."""
        return _TargetIterationCount(n_iters)


class _TargetTimeDuration(TargetDuration):
    """TargetDuration expressed as a wall-clock time budget in seconds."""

    def __init__(self, t_target_sec: float):
        """
        :param t_target_sec: time budget in seconds; must be strictly positive and finite.
        :raises ValueError: if t_target_sec is <= 0, NaN or infinite.
        """
        if t_target_sec <= 0:
            raise ValueError("t_target_sec must be > 0")
        if not math.isfinite(t_target_sec):
            # NaN slips through the `<= 0` check (every NaN comparison is False) and
            # +inf would only fail later, when _TimeTracker computes int(max_seconds).
            raise ValueError("t_target_sec must be finite")
        self._t_target_sec = t_target_sec

    def __str__(self):
        return repr(self)

    def __repr__(self):
        # number of displayed decimals shrinks as the duration grows
        if self._t_target_sec <= 1.0:
            return f"TargetDuration({self._t_target_sec:.3f} seconds)"
        elif self._t_target_sec < 10.0:
            return f"TargetDuration({self._t_target_sec:.2f} seconds)"
        elif self._t_target_sec < 100.0:
            return f"TargetDuration({self._t_target_sec:.1f} seconds)"
        else:
            return f"TargetDuration({int(round(self._t_target_sec)):_} seconds)"

    def track(self) -> ProgressTracker:
        """Start the clock and return a tracker for this time budget."""
        return _TimeTracker(self._t_target_sec)


class _TargetIterationCount(TargetDuration):
    """TargetDuration expressed as a fixed number of iterations."""

    def __init__(self, n_iters: int):
        """
        :param n_iters: number of iterations to run; must be > 0.
        :raises ValueError: if n_iters <= 0.
        """
        if n_iters <= 0:
            raise ValueError("n_iters must be > 0")
        self._n_iters = n_iters

    def __repr__(self):
        # underscore thousands separator, e.g. 'TargetDuration(1_000 iterations)'
        return f"TargetDuration({self._n_iters:_} iterations)"

    def __str__(self):
        return self.__repr__()

    def track(self) -> ProgressTracker:
        """Start counting iterations against this target."""
        return _IterationTracker(self._n_iters)


# --- shorthand factory methods ---------------------------
#   module-level aliases, so callers can write e.g. `seconds(30)` or
#   `iterations(1_000)` instead of `TargetDuration.seconds(30)`
iterations, seconds, minutes, hours = (
    TargetDuration.iterations,
    TargetDuration.seconds,
    TargetDuration.minutes,
    TargetDuration.hours,
)


# =================================================================================================
#  ProgressTracker
# =================================================================================================
class ProgressTracker(ABC):
    """
    Abstract tracker of algorithm progress (iterations, wall-clock time, ...).

    The clock starts as soon as the tracker is constructed. Callers report work via
    `iterations_done(...)`; subclasses decide how that maps onto a Progress snapshot
    and onto an estimate of the remaining iterations.
    """

    def __init__(self):
        self._t_start = time.perf_counter()  # reference point for all elapsed-time queries
        self._iter_count = 0

    def _seconds_since_start(self) -> float:
        """Wall-clock seconds elapsed since construction."""
        return time.perf_counter() - self._t_start

    def iterations_done(self, n: int):
        """Report that `n` additional iterations have been completed."""
        self._iter_count += n

    def iter_count(self) -> int:
        """Total number of iterations reported so far."""
        return self._iter_count

    def elapsed(self) -> Elapsed:
        """Snapshot of elapsed wall-clock time and iterations reported so far."""
        return Elapsed(
            t_elapsed_sec=self._seconds_since_start(),
            n_iterations=self._iter_count,
        )

    def iters_per_second(self):
        """Observed iteration rate; 1.0 if no measurable time has passed yet."""
        t_elapsed = self._seconds_since_start()
        if t_elapsed <= 0.0:
            return 1.0
        return self._iter_count / t_elapsed

    @abstractmethod
    def get_progress(self) -> Progress:
        """Return a Progress snapshot for the current state."""
        raise NotImplementedError()

    @abstractmethod
    def estimated_n_iterations_remaining(self) -> int:
        """Return an estimate of how many more iterations fit in the target."""
        raise NotImplementedError()


class _TimeTracker(ProgressTracker):
    """
    Tracks progress towards a wall-clock budget of `max_seconds` seconds.

    Progress is reported in whole seconds: `n_total` is the budget rounded down to an
    integer (at least 1), and `n_current` only reaches `n_total` once the full budget
    has actually elapsed.
    """

    def __init__(self, max_seconds: float):
        super().__init__()
        self._max_seconds = max_seconds
        # 1-second granularity; keep at least one step so sub-second budgets still
        # yield a non-zero total
        self._n_total = max(1, int(max_seconds))

    def get_progress(self) -> Progress:
        t_elapsed = time.perf_counter() - self._t_start
        if t_elapsed >= self._max_seconds:
            n_current = self._n_total
        else:
            # never report 'finished' before the budget is used up, even when
            # int(t_elapsed) already equals n_total (non-integer budgets)
            n_current = min(int(t_elapsed), self._n_total - 1)

        return Progress(
            n_current=n_current,
            n_total=self._n_total,
        )

    def estimated_n_iterations_remaining(self) -> int:
        """
        Estimate how many more iterations fit in the remaining time, extrapolating
        from the iteration rate observed so far.

        Returns 0 once the budget is exhausted. While time remains but no rate can be
        measured yet (no iterations reported, or no measurable time passed), returns 1
        so that the caller performs at least one iteration instead of stopping early.
        """
        t_elapsed = time.perf_counter() - self._t_start
        if t_elapsed >= self._max_seconds:
            return 0
        if t_elapsed <= 0.0 or self._iter_count == 0:
            # no throughput information yet -> cannot extrapolate (and would divide by 0)
            return 1
        iters_per_sec = self._iter_count / t_elapsed
        iters_remaining = iters_per_sec * (self._max_seconds - t_elapsed)
        return int(iters_remaining)


class _IterationTracker(ProgressTracker):
    """ProgressTracker for a fixed budget of `max_iters` iterations."""

    def __init__(self, max_iters: int):
        super().__init__()
        self._max_iters = max_iters

    def get_progress(self) -> Progress:
        # clamp so that overshooting the budget still reports exactly 100%
        done = min(self._iter_count, self._max_iters)
        return Progress(n_current=done, n_total=self._max_iters)

    def estimated_n_iterations_remaining(self) -> int:
        """Exact number of iterations left in the budget (never negative)."""
        remaining = self._max_iters - self._iter_count
        return remaining if remaining > 0 else 0


# =================================================================================================
#  Progress
# =================================================================================================
@dataclass(frozen=True, slots=True)
class Progress:
    n_current: int
    n_total: int

    @property
    def is_finished(self) -> bool:
        return self.n_current >= self.n_total

    def update_tqdm(self, pbar: tqdm):
        """Updates a tqdm progress bar to reflect current progress"""
        pbar.n = self.n_current
        pbar.total = self.n_total
        pbar.refresh()


# =================================================================================================
#  Elapsed
# =================================================================================================
@dataclass(frozen=True, slots=True)
class Elapsed:
    t_elapsed_sec: float
    n_iterations: int

    def __add__(self, other: Elapsed) -> Elapsed:
        if other == 0:
            return self  # helps ensure sum() works correctly
        elif not isinstance(other, Elapsed):
            return NotImplemented
        return Elapsed(
            t_elapsed_sec=self.t_elapsed_sec + other.t_elapsed_sec,
            n_iterations=self.n_iterations + other.n_iterations,
        )

    def __radd__(self, other: Elapsed | int) -> Elapsed:
        return self + other


# Public API of this module; Python only honours the lowercase `__all__` name
# (e.g. for `from <module> import *`).
__all__ = [
    "TargetDuration",
    "ProgressTracker",
    "Progress",
    "Elapsed",
    "iterations",
    "seconds",
    "minutes",
    "hours",
]

# Backwards-compatible alias for the previously misspelled name.
__ALL__ = __all__
