import os
import json
import random

import lazyllm
from lazyllm import launchers, LazyLLMCMD, ArgsDict, LOG
from .base import LazyLLMDeployBase, verify_fastapi_func

# Register the "default_embedding_engine" config option (type str, default "").
# NOTE(review): the last argument looks like the environment-variable name that can
# override this option — confirm against lazyllm.config.add. The Infinity class below
# does not read this option itself; presumably it is consumed elsewhere.
lazyllm.config.add("default_embedding_engine", str, "", "DEFAULT_EMBEDDING_ENGINE")

class Infinity(LazyLLMDeployBase):
    """This class is a subclass of ``LazyLLMDeployBase``, providing high-performance text-embeddings, reranking, and CLIP capabilities based on the [Infinity](https://github.com/michaelfeil/infinity) framework.

Args:
    launcher (lazyllm.launcher): The launcher for Infinity, defaulting to ``launchers.remote(ngpus=1)``.
    model_type (str): The kind of model being served, defaulting to ``'embed'``. ``'reranker'`` serves the ``/rerank`` endpoint; any other value serves ``/embeddings``.
    kw: Keyword arguments for updating the default deployment parameters. Note that no additional keyword arguments can be passed here except those listed below.

The keyword arguments and their default values for this class are as follows:

Keyword Args:
    host (str): The IP address of the service, defaulting to ``0.0.0.0``.
    port (int): The port number of the service, defaulting to ``None``, in which case LazyLLM will automatically generate a random port number.
    batch-size (int): The maximum batch size, defaulting to ``256``.


Examples:
    >>> import lazyllm
    >>> from lazyllm import deploy
    >>> deploy.Infinity()
    <lazyllm.llm.deploy type=Infinity>
    """
    # Request template used in embedding mode. Presumably the base class maps the generic
    # 'inputs' field onto Infinity's 'input' key — confirm against LazyLLMDeployBase.
    keys_name_handle = {
        'inputs': 'input',
    }
    message_format = {
        'input': 'who are you ?',
    }
    default_headers = {'Content-Type': 'application/json'}

    def __init__(self,
                 launcher=launchers.remote(ngpus=1),
                 model_type='embed',
                 **kw,
                 ):
        # NOTE(review): the default launcher is created once, when the class is defined, and is
        # shared by every instance that relies on the default — confirm the base class clones it.
        super().__init__(launcher=launcher)
        self.kw = ArgsDict({
            'host': '0.0.0.0',
            'port': None,
            'batch-size': 256,
        })
        self._model_type = model_type
        self.kw.check_and_update(kw)
        # A random port is drawn at command-build time only when the caller did not pin one.
        self.random_port = not kw.get('port')
        if self._model_type == "reranker":
            self._update_reranker_message()

    def _update_reranker_message(self):
        """Switch this instance's request template to the reranker payload shape.

        The instance attributes set here shadow the embedding-mode class attributes.
        """
        self.keys_name_handle = {
            'inputs': 'query',
        }
        self.message_format = {
            'query': 'who are you ?',
            'documents': ['string'],
            'return_documents': False,
            'raw_scores': False,
            'top_n': 1,
            'model': 'default/not-specified',
        }
        self.default_headers = {'Content-Type': 'application/json'}

    @staticmethod
    def _has_model_weights(path):
        """Return True if ``path`` is an existing directory holding ``.bin``/``.safetensors`` files."""
        # Guard against None/'' and plain files, which os.path.exists/os.listdir do not handle.
        if not path or not os.path.isdir(path):
            return False
        return any(name.endswith(('.bin', '.safetensors')) for name in os.listdir(path))

    def cmd(self, finetuned_model=None, base_model=None):
        """Build the ``infinity_emb v2`` launch command.

        Args:
            finetuned_model (str, optional): Local directory of a fine-tuned model. It is used only
                when it is an existing directory containing ``.bin`` or ``.safetensors`` files.
            base_model (str, optional): Model id/path used when ``finetuned_model`` is unusable.

        Returns:
            LazyLLMCMD: wraps a lazily built command string, with ``self.geturl`` as the return
            value and ``verify_fastapi_func`` as the check function.
        """
        if not self._has_model_weights(finetuned_model):
            # Only warn when a path was actually supplied. An absent fine-tuned model is the
            # normal case and silently falls back to the base model.
            if finetuned_model:
                LOG.warning(f"Note! That finetuned_model({finetuned_model}) is an invalid path, "
                            f"base_model({base_model}) will be used")
            finetuned_model = base_model

        def impl():
            # Each call draws a fresh port unless the user fixed one in __init__.
            if self.random_port:
                self.kw['port'] = random.randint(30000, 40000)
            return f'infinity_emb v2 --model-id {finetuned_model} ' + self.kw.parse_kwargs()

        return LazyLLMCMD(cmd=impl, return_value=self.geturl, checkf=verify_fastapi_func)

    def geturl(self, job=None):
        """Return the HTTP endpoint of the deployed service.

        The path is ``/rerank`` for reranker models and ``/embeddings`` otherwise. In
        ``lazyllm.Mode.Display`` mode a placeholder URL is returned instead of a real address.

        Args:
            job: The job whose IP is used; defaults to ``self.job``.
        """
        if job is None:
            job = self.job
        target_name = 'rerank' if self._model_type == "reranker" else 'embeddings'
        if lazyllm.config['mode'] == lazyllm.Mode.Display:
            return f'http://<ip>:<port>/{target_name}'
        return f'http://{job.get_jobip()}:{self.kw["port"]}/{target_name}'

    @staticmethod
    def extract_result(x, inputs):
        """Convert a raw Infinity JSON response into LazyLLM's result format.

        Args:
            x (str): Raw response body.
            inputs (dict): The request payload. For embeddings, ``inputs['input']`` decides
                whether a single vector is unwrapped.

        Returns:
            The return value depends on the response's ``object`` field:

            - ``embedding``: a JSON string of the vector list. When exactly one vector came back
              for a ``str`` input, the string holds that single vector instead of a list.
            - ``rerank``: the list of result ``index`` values, in response order.
            - any other type: ``None``.

        Raises:
            Exception: whatever ``json.loads`` raises on an undecodable body (logged first).
            AssertionError: if the decoded response has no ``object`` field.
        """
        try:
            res_object = json.loads(x)
        except Exception:
            LOG.warning(f'JSONDecodeError on load {x}')
            raise
        assert 'object' in res_object
        object_type = res_object['object']
        if object_type == 'embedding':
            embeddings = [item['embedding'] for item in res_object['data']]
            # A single string input yields one vector rather than a one-element list.
            if len(embeddings) == 1 and isinstance(inputs['input'], str):
                embeddings = embeddings[0]
            return json.dumps(embeddings)
        elif object_type == 'rerank':
            return [item['index'] for item in res_object['results']]
