from abc import ABC, abstractmethod
from io import BytesIO
import tarfile
from tarfile import TarInfo
from typing import Any, Callable, Dict, IO, Tuple, Type, Union

from pydantic import ValidationError

from omnipy.data.dataset import Dataset


class Serializer(ABC):
    """Abstract interface for converting a dataset to and from raw bytes.

    All hooks are classmethods: serializers are used as classes (e.g. stored as
    types in a registry) rather than instantiated.
    """

    @classmethod
    @abstractmethod
    def get_supported_dataset_type(cls) -> Type[Dataset]:
        """Return the dataset class this serializer is able to (de)serialize.

        Declared abstract (like the other hooks) so a subclass cannot silently
        inherit an implementation that returns ``None`` instead of a class.
        """
        pass

    @classmethod
    @abstractmethod
    def serialize(cls, dataset: Dataset) -> Union[bytes, memoryview]:
        """Encode ``dataset`` into a bytes-like object."""
        pass

    @classmethod
    @abstractmethod
    def deserialize(cls, serialized: bytes) -> Dataset:
        """Decode ``serialized`` (as produced by ``serialize``) back into a dataset."""
        pass


class TarFileSerializer(Serializer, ABC):
    """Base for serializers that store a dataset as an in-memory gzip-compressed tar archive.

    Each item of the dataset becomes one archive member named ``<key>.<file_suffix>``;
    subclasses supply the functions that encode/decode the individual items.
    """

    @classmethod
    def create_tarfile_from_dataset(cls,
                                    dataset: Dataset,
                                    file_suffix: str,
                                    data_encode_func: Callable[[Any], Union[bytes, memoryview]]):
        """Pack every item of ``dataset`` into a gzip-compressed tar archive.

        :param dataset: dataset whose ``items()`` are written, one archive member per item
        :param file_suffix: extension (without leading dot) appended to each member name
        :param data_encode_func: encodes a single data object into a bytes-like object
        :return: the bytes of the gzip-compressed tar archive
        """
        bytes_io = BytesIO()
        with tarfile.open(fileobj=bytes_io, mode='w:gz') as tarfile_stream:
            for obj_type, data_obj in dataset.items():
                # A freshly constructed BytesIO is already positioned at offset 0
                encoded_stream = BytesIO(data_encode_func(data_obj))
                tarinfo = TarInfo(name=f'{obj_type}.{file_suffix}')
                # getbuffer() exposes the copied data byte-wise, so len() is the byte count
                tarinfo.size = len(encoded_stream.getbuffer())
                tarfile_stream.addfile(tarinfo, encoded_stream)
        return bytes_io.getvalue()

    @classmethod
    def create_dataset_from_tarfile(cls,
                                    dataset: Dataset,
                                    tarfile_bytes: bytes,
                                    file_suffix: str,
                                    data_decode_func: Callable[[IO[bytes]], Any],
                                    dictify_object_func: Callable[[str, Any], Union[Dict, str]],
                                    import_method='from_data'):
        """Load the members of a gzip-compressed tar archive into ``dataset``.

        For each member, the name with ``.<file_suffix>`` removed is used as the
        object name. The member content is decoded with ``data_decode_func``. The
        result of ``dictify_object_func(name, decoded)`` is then passed to
        ``getattr(dataset, import_method)``.

        :param dataset: dataset instance to import into (mutated in place)
        :param tarfile_bytes: bytes of a gzip-compressed tar archive
        :param file_suffix: extension (without leading dot) every member name must carry
        :param data_decode_func: decodes a binary file object into a data object
        :param dictify_object_func: maps (object name, decoded object) to the import payload
        :param import_method: name of the dataset method receiving each payload
        :raises ValueError: if a member lacks the expected suffix or is not a regular file
        """
        dotted_suffix = f'.{file_suffix}'
        with tarfile.open(fileobj=BytesIO(tarfile_bytes), mode='r:gz') as tarfile_stream:
            # Iterate members directly: avoids a per-name lookup and handles
            # every member even if names repeat.
            for member in tarfile_stream.getmembers():
                filename = member.name
                if not filename.endswith(dotted_suffix):
                    raise ValueError(f'Tar file member "{filename}" does not end with the '
                                     f'expected suffix "{dotted_suffix}"')

                obj_type_file = tarfile_stream.extractfile(member)
                if obj_type_file is None:
                    raise ValueError(f'Tar file member "{filename}" is not a regular file')

                # Strip the exact suffix so that suffixes containing dots (e.g.
                # "tar.json") still recover the name used when writing the archive.
                obj_type = filename[:-len(dotted_suffix)]
                # Keep decoding and importing inside the with-block in case the
                # decoder returns something that still reads from the file lazily.
                with obj_type_file:
                    getattr(dataset, import_method)(
                        dictify_object_func(obj_type, data_decode_func(obj_type_file)))


class SerializerRegistry:
    """Ordered collection of serializer classes, with dataset auto-detection.

    Serializers are tried in registration order when auto-detecting.
    """

    def __init__(self) -> None:
        self._serializer_classes: list[Type[Serializer]] = []

    def register(self, serializer_cls: Type[Serializer]) -> None:
        """Append ``serializer_cls`` to the registry."""
        self._serializer_classes.append(serializer_cls)

    @property
    def serializers(self) -> Tuple[Type[Serializer], ...]:
        """All registered serializer classes, in registration order."""
        return tuple(self._serializer_classes)

    @property
    def tar_file_serializers(self) -> Tuple[Type[TarFileSerializer], ...]:
        """Registered serializer classes that are subclasses of ``TarFileSerializer``."""
        return tuple(cls for cls in self._serializer_classes if issubclass(cls, TarFileSerializer))

    def auto_detect(
            self,
            dataset: Dataset) -> Tuple[Union[Dataset, None], Union[Type[Serializer], None]]:
        """Find a registered serializer able to handle ``dataset``.

        :return: ``(converted_dataset, serializer_cls)``, or ``(None, None)`` if none fits
        """
        return self._autodetect_serializer(dataset, self.serializers)

    def auto_detect_tar_file_serializer(
            self,
            dataset: Dataset) -> Tuple[Union[Dataset, None], Union[Type[Serializer], None]]:
        """Like ``auto_detect``, but only considers ``TarFileSerializer`` subclasses."""
        return self._autodetect_serializer(dataset, self.tar_file_serializers)

    @classmethod
    def _autodetect_serializer(
            cls, dataset,
            serializers) -> Tuple[Union[Dataset, None], Union[Type[Serializer], None]]:
        """Return the first (converted dataset, serializer) pair whose conversion succeeds.

        All serializers are first tried with the ``from_json`` conversion. Only if
        none succeed are they tried again with the ``from_data`` conversion.
        Conversion failures raising ``TypeError``, ``ValueError`` or pydantic's
        ``ValidationError`` are treated as "not supported" and skipped (deliberate
        best-effort detection).
        """

        def _to_data_from_json(dataset, new_dataset_cls):
            # NOTE(review): feeds to_data() output into from_json(); presumably
            # intended for datasets whose data items are JSON strings — confirm.
            new_dataset = new_dataset_cls()
            new_dataset.from_json(dataset.to_data())
            return new_dataset

        def _to_data_from_data(dataset, new_dataset_cls):
            new_dataset = new_dataset_cls()
            new_dataset.from_data(dataset.to_data())
            return new_dataset

        for func in (_to_data_from_json, _to_data_from_data):
            for serializer in serializers:
                new_dataset_cls = serializer.get_supported_dataset_type()

                try:
                    new_dataset = func(dataset, new_dataset_cls)
                except (TypeError, ValueError, ValidationError):
                    # This serializer cannot represent the dataset; try the next one
                    continue
                return new_dataset, serializer

        return None, None
