import json
from typing import Mapping, Any

import grpc

from sapiopycommons.ai.api.fielddefinitions.proto.fields_pb2 import FieldValuePbo
from sapiopycommons.ai.api.plan.tool.proto.entry_pb2 import DataTypePbo, StepBinaryContainerPbo, StepCsvRowPbo, \
    StepCsvHeaderRowPbo, StepCsvContainerPbo, StepJsonContainerPbo, StepImageContainerPbo, StepTextContainerPbo, \
    StepItemContainerPbo, StepInputBatchPbo
from sapiopycommons.ai.api.plan.tool.proto.tool_pb2 import ProcessStepResponsePbo, ProcessStepRequestPbo, \
    ToolDetailsRequestPbo, ToolDetailsResponsePbo
from sapiopycommons.ai.api.plan.tool.proto.tool_pb2_grpc import ToolServiceStub
from sapiopycommons.ai.api.session.proto.sapio_conn_info_pb2 import SapioConnectionInfoPbo, SapioUserSecretTypePbo


class ToolOutput:
    """
    A class for holding the output of a TestClient that calls a ToolService. ToolOutput objects can be
    printed to show the output of the tool in a human-readable format.
    """
    # The name of the tool that produced this output.
    tool_name: str

    # Output items, grouped by the kind of container they were returned in.
    binary_output: list[bytes]
    csv_output: list[dict[str, Any]]
    json_output: list[Any]
    image_output: list[bytes]
    text_output: list[str]

    # Field maps (field name -> field value message) of any new records returned by the tool.
    # The type is quoted so the class body does not evaluate the project type at definition time.
    new_records: list[Mapping[str, "FieldValuePbo"]]

    # Log lines returned by the tool.
    logs: list[str]

    def __init__(self, tool_name: str):
        """
        :param tool_name: The name of the tool that produced this output.
        """
        self.tool_name = tool_name
        self.binary_output = []
        self.csv_output = []
        self.json_output = []
        self.image_output = []
        self.text_output = []
        self.new_records = []
        self.logs = []

    def __str__(self) -> str:
        """
        Render every output category as a human-readable, tab-indented report.

        Binary and image items are summarized by size and their first 50 bytes.
        """
        lines: list[str] = [f"{self.tool_name} Output:", "-" * 25]

        lines.append(f"Binary Output: {len(self.binary_output)} item(s)")
        for binary in self.binary_output:
            lines.append(f"\t{len(binary)} byte(s)")
            lines.append(f"\t{binary[:50]}...")

        lines.append(f"CSV Output: {len(self.csv_output)} item(s)")
        if self.csv_output:
            # Headers are taken from the first row's keys.
            lines.append(f"\tHeaders: {', '.join(self.csv_output[0].keys())}")
            for i, csv_row in enumerate(self.csv_output):
                lines.append(f"\t{i}: {', '.join(str(value) for value in csv_row.values())}")

        lines.append(f"JSON Output: {len(self.json_output)} item(s)")
        if self.json_output:
            lines.append(f"\t{json.dumps(self.json_output, indent=2)}")

        lines.append(f"Image Output: {len(self.image_output)} item(s)")
        for image in self.image_output:
            lines.append(f"\t{len(image)} byte(s)")
            lines.append(f"\t{image[:50]}...")

        lines.append(f"Text Output: {len(self.text_output)} item(s)")
        for text in self.text_output:
            lines.append(f"\t{text}...")

        lines.append(f"New Records: {len(self.new_records)} item(s)")
        for i, record in enumerate(self.new_records):
            lines.append(f"\tRecord {i}:")
            # Record values are protobuf messages, which json.dumps cannot serialize; render each field
            # with the value's own string form instead.
            for field_name, value in record.items():
                lines.append(f"\t\t{field_name}: {str(value).strip()}")

        lines.append(f"Logs: {len(self.logs)} item(s)")
        lines.extend(f"\t{log}" for log in self.logs)

        return "\n".join(lines) + "\n"


class TestClient:
    """
    A client for testing a ToolService. This client can be used to send requests to a tool and receive
    responses.
    """
    server_url: str
    connection: SapioConnectionInfoPbo
    _request_inputs: list[Any]
    _config_fields: dict[str, Any]

    def __init__(self, server_url: str):
        """
        :param server_url: The URL of the gRPC server to connect to.
        """
        self.create_user()
        self.server_url = server_url
        self._request_inputs = []
        self._config_fields = {}

    def create_user(self):
        """
        Create a SapioConnectionInfoPbo object with test credentials. This method can be overridden to
        create a user with specific credentials for testing.
        """
        self.connection = SapioConnectionInfoPbo()
        self.connection.username = "Testing"
        self.connection.webservice_url = "https://localhost:8080/webservice/api"
        self.connection.app_guid = "1234567890"
        self.connection.secret_type = SapioUserSecretTypePbo.PASSWORD
        self.connection.rmi_host.append("Testing")
        self.connection.rmi_port = 9001
        self.connection.secret = "password"

    def add_input_input(self, input_data: list[bytes]) -> None:
        """
        Add a binary input to the the next request.
        """
        self._add_input(DataTypePbo.BINARY, StepBinaryContainerPbo(items=input_data))

    def add_csv_input(self, input_data: list[dict[str, Any]]) -> None:
        """
        Add a CSV input to the next request.
        """
        csv_items = []
        for row in input_data:
            csv_items.append(StepCsvRowPbo(cells=[str(value) for value in row.values()]))
        header = StepCsvHeaderRowPbo(cells=list(input_data[0].keys()))
        self._add_input(DataTypePbo.CSV, StepCsvContainerPbo(header=header, items=csv_items))

    def add_json_input(self, input_data: list[dict[str, Any]]) -> None:
        """
        Add a JSON input to the next request.
        """
        self._add_input(DataTypePbo.JSON, StepJsonContainerPbo(items=[json.dumps(x) for x in input_data]))

    def add_image_input(self, input_data: list[bytes], image_format: str = "png") -> None:
        """
        Add an image input to the next request.
        """
        self._add_input(DataTypePbo.IMAGE, StepImageContainerPbo(items=input_data, image_format=image_format))

    def add_text_input(self, input_data: list[str]) -> None:
        """
        Add a text input to the next request.
        """
        self._add_input(DataTypePbo.TEXT, StepTextContainerPbo(items=input_data))

    def clear_inputs(self) -> None:
        """
        Clear all inputs that have been added to the next request.
        This is useful if you want to start a new request without the previous inputs.
        """
        self._request_inputs.clear()

    def add_config_field(self, field_name: str, value: Any) -> None:
        """
        Add a configuration field value to the next request.

        :param field_name: The name of the configuration field.
        :param value: The value to set for the configuration field.
        """
        self._config_fields[field_name] = value

    def add_config_fields(self, config_fields: dict[str, Any]) -> None:
        """
        Add multiple configuration field values to the next request.

        :param config_fields: A dictionary of configuration field names and their corresponding values.
        """
        self._config_fields.update(config_fields)

    def clear_configs(self) -> None:
        """
        Clear all configuration field values that have been added to the next request.
        This is useful if you want to start a new request without the previous configurations.
        """
        self._config_fields.clear()

    def clear_request(self) -> None:
        """
        Clear all inputs and configuration fields that have been added to the next request.
        This is useful if you want to start a new request without the previous inputs and configurations.
        """
        self.clear_inputs()
        self.clear_configs()

    def _add_input(self, data_type: DataTypePbo, items: Any) -> None:
        """
        Helper method for adding inputs to the next request.
        """
        match data_type:
            case DataTypePbo.BINARY:
                container = StepItemContainerPbo(dataType=data_type, binary_container=items)
            case DataTypePbo.CSV:
                container = StepItemContainerPbo(dataType=data_type, csv_container=items)
            case DataTypePbo.JSON:
                container = StepItemContainerPbo(dataType=data_type, json_container=items)
            case DataTypePbo.IMAGE:
                container = StepItemContainerPbo(dataType=data_type, image_container=items)
            case DataTypePbo.TEXT:
                container = StepItemContainerPbo(dataType=data_type, text_container=items)
            case _:
                raise ValueError(f"Unsupported data type: {data_type}")
        self._request_inputs.append(container)

    def get_service_details(self) -> ToolDetailsResponsePbo:
        """
        Get the details of the tools from the server.

        :return: A ToolDetailsResponsePbo object containing the details of the tool service.
        """
        with grpc.insecure_channel(self.server_url) as channel:
            stub = ToolServiceStub(channel)
            return stub.GetToolDetails(ToolDetailsRequestPbo(sapio_conn_info=self.connection))

    def call_tool(self, tool_name: str) -> ToolOutput:
        """
        Send the request to the tool service for a particular tool name. This will send all the inputs that have been
        added using the add_X_input functions.

        :param tool_name: The name of the tool to call on the server.
        :return: A ToolOutput object containing the results of the tool service call.
        """
        with grpc.insecure_channel(self.server_url) as channel:
            stub = ToolServiceStub(channel)

            response: ProcessStepResponsePbo = stub.ProcessData(
                ProcessStepRequestPbo(
                    sapio_user=self.connection,
                    tool_name=tool_name,
                    config_field_values=self._config_fields,
                    input=[
                        StepInputBatchPbo(is_partial=False, item_container=item)
                        for item in self._request_inputs
                    ]
                )
            )

            results = ToolOutput(tool_name)
            for item in response.output:
                container = item.item_container

                results.binary_output.extend(container.binary_container.items)
                for header in container.csv_container.header.cells:
                    output_row: dict[str, Any] = {}
                    for i, row in enumerate(container.csv_container.items):
                        output_row[header] = row.cells[i]
                    results.csv_output.append(output_row)
                results.json_output.extend([json.loads(x) for x in container.json_container.items])
                results.image_output.extend(container.image_container.items)
                results.text_output.extend(container.text_container.items)

            for record in response.new_records:
                results.new_records.append(record.fields)

            results.logs.extend(response.log)

            return results
