from .huggingface_utils import get_auth_token
from .ort_settings import get_onnx_runtime_sessions
from .onnx_exporter import (
    generate_onnx_representation,
    quantize,
    get_model_paths,
    saved_models_path,
)
from pathlib import Path

from transformers import (
    AutoConfig,
    MT5Config,
    T5ForConditionalGeneration,
)
from transformers.modeling_outputs import (
    Seq2SeqLMOutput,
    BaseModelOutput,
)
import torch
import functools
import operator
import numpy


class T5Encoder(torch.nn.Module):
    """Expose an ONNX encoder session through a ``torch.nn.Module`` interface."""

    def __init__(self, encoder_sess):
        """Store the session object whose ``run`` method performs the encoding."""
        super().__init__()
        self.encoder = encoder_sess
        # NOTE(review): presumably read by the Hugging Face generation
        # utilities to find the encoder's primary input -- confirm.
        self.main_input_name = "input_ids"

    def forward(
        self,
        input_ids,
        attention_mask,
        inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder session and wrap its first output.

        Only ``input_ids`` and ``attention_mask`` are forwarded to the
        session; the remaining keyword arguments are accepted but ignored.

        Returns:
            A ``BaseModelOutput`` built from the first session output,
            converted to a torch tensor.
        """
        # The session is fed numpy arrays, so tensors are moved to the CPU
        # and converted first.
        session_feed = {
            "input_ids": input_ids.cpu().numpy(),
            "attention_mask": attention_mask.cpu().numpy(),
        }
        first_output = self.encoder.run(None, session_feed)[0]
        return BaseModelOutput(torch.from_numpy(first_output))


class T5DecoderInit(torch.nn.Module):
    """Wrap the ONNX decoder session that runs without past key/values."""

    def __init__(self, decoder_sess):
        """Store the session object whose ``run`` method performs decoding."""
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
        """Run the session and regroup its outputs.

        Returns:
            A ``(first_output, past_key_values)`` pair. ``first_output`` is
            the first session output as a tensor; ``past_key_values`` is a
            tuple holding the remaining outputs as tensors, chunked into
            consecutive tuples of (up to) four.
        """
        session_feed = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
            "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
        }
        raw_outputs = self.decoder.run(None, session_feed)

        # Everything after the first output is treated as cache entries.
        cache_tensors = tuple(map(torch.from_numpy, raw_outputs[1:]))

        # Chunk the flat cache into groups of four consecutive tensors.
        chunk_starts = range(0, len(cache_tensors), 4)
        grouped_cache = tuple(
            cache_tensors[start: start + 4] for start in chunk_starts
        )

        return torch.from_numpy(raw_outputs[0]), grouped_cache


class T5Decoder(torch.nn.Module):
    """Wrap the ONNX decoder session that consumes past key/values."""

    def __init__(self, decoder_sess):
        """Store the session object whose ``run`` method performs decoding."""
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
        """Run one decoding step with cached key/values.

        ``past_key_values`` is flattened in order and each entry is fed to
        the session under the name ``pkv_<index>``; ``attention_mask`` is fed
        as ``encoder_attention_mask`` and ``encoder_output`` as
        ``encoder_hidden_states``.

        Returns:
            A ``(first_output, past_key_values)`` pair. ``first_output`` is
            the first session output as a tensor; ``past_key_values`` is a
            tuple holding the remaining outputs as tensors, chunked into
            consecutive tuples of (up to) four.
        """
        session_feed = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": attention_mask.cpu().numpy(),
            "encoder_hidden_states": encoder_output.cpu().numpy(),
        }

        # Flatten the nested cache into one ordered list of tensors.
        flat_cache = [entry for group in past_key_values for entry in group]
        session_feed.update(
            (f"pkv_{index}", entry.cpu().numpy())
            for index, entry in enumerate(flat_cache)
        )

        raw_outputs = self.decoder.run(None, session_feed)

        # Everything after the first output is the updated cache.
        cache_tensors = tuple(map(torch.from_numpy, raw_outputs[1:]))

        # Re-nest the flat cache into groups of four consecutive tensors.
        chunk_starts = range(0, len(cache_tensors), 4)
        grouped_cache = tuple(
            cache_tensors[start: start + 4] for start in chunk_starts
        )

        return torch.from_numpy(raw_outputs[0]), grouped_cache


class OnnxT5(T5ForConditionalGeneration):
    """creates a T5 model using onnx sessions (encode, decoder & init_decoder)

    The encoder and decoder attributes of the parent model are replaced by
    thin wrappers around ONNX sessions, and ``forward`` routes through them.
    """

    def __init__(self, model_or_model_path, onnx_model_sessions):
        """Load the configuration and install the ONNX-backed submodules.

        Args:
            model_or_model_path: value handed to ``AutoConfig.from_pretrained``;
                either a string or an object exposing ``name_or_path``.
            onnx_model_sessions: exactly three sessions, ordered as
                ``(encoder, decoder, decoder_init)``.

        Raises:
            AssertionError: if ``onnx_model_sessions`` does not hold three
                items.
        """
        config = AutoConfig.from_pretrained(
            model_or_model_path, use_auth_token=get_auth_token()
        )
        super().__init__(config)

        # monkeypatch to work for MT5
        # Resolve one identifier string so the "mt5" check is case-insensitive
        # for both plain strings and objects carrying ``name_or_path``; a
        # missing or non-string identifier simply means "not mT5".
        if isinstance(model_or_model_path, str):
            identifier = model_or_model_path
        else:
            identifier = getattr(model_or_model_path, "name_or_path", None)

        if isinstance(identifier, str) and "mt5" in identifier.lower():
            self.model_type = "mt5"
            self.config_class = MT5Config
            self._keys_to_ignore_on_load_missing = [
                r"encoder\.embed_tokens\.weight",
            ]
            self._keys_to_ignore_on_save = [
                r"encoder\.embed_tokens\.weight",
            ]

        assert len(onnx_model_sessions) == 3, "all three models should be given"

        encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions

        self.encoder = T5Encoder(encoder_sess)
        self.decoder = T5Decoder(decoder_sess)
        self.decoder_init = T5DecoderInit(decoder_sess_init)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder (if needed) and one decoder pass through ONNX.

        Only ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
        ``encoder_outputs`` and ``past_key_values`` influence the result; the
        other parameters are accepted for signature compatibility with the
        parent class and are ignored here.

        Returns:
            ``Seq2SeqLMOutput`` carrying only ``logits`` and the updated
            ``past_key_values``.
        """
        if encoder_outputs is None:
            # No cached encoder result was supplied, so encode the inputs now.
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )

        encoder_hidden_states = encoder_outputs[0]

        if past_key_values is None:
            # runs only for the first time: no cache exists yet, so the
            # init decoder receives the full decoder_input_ids.
            logits, past_key_values = self.decoder_init(
                decoder_input_ids, attention_mask, encoder_hidden_states
            )
        else:
            # With a cache present only the most recent decoder token is fed.
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]

            logits, past_key_values = self.decoder(
                decoder_input_ids,
                attention_mask,
                encoder_hidden_states,
                past_key_values,
            )

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)


def export_and_get_onnx_model(
    model_or_model_path, custom_output_path=saved_models_path, quantized=True
):
    """Run the whole pipeline and return the assembled ONNX-backed model.

    Steps: export to ONNX -> optionally quantize -> create ONNX Runtime
    sessions -> build ``OnnxT5`` from those sessions.

    Args:
        model_or_model_path: passed to ``generate_onnx_representation`` and
            to ``OnnxT5``.
        custom_output_path: output location handed to the exporter.
        quantized: when truthy, sessions are created from the quantized
            model files instead of the freshly exported ones.
    """
    # Step 1. convert huggingfaces t5 model to onnx
    exported_paths = generate_onnx_representation(
        model_or_model_path, output_path=custom_output_path
    )

    # Step 2. (recommended) quantize the converted model for fast inference
    # and to reduce model size.
    session_paths = quantize(exported_paths) if quantized else exported_paths

    # Step 3. setup onnx runtime
    print("Setting up onnx model...")
    sessions = get_onnx_runtime_sessions(session_paths)

    # Step 4. get the onnx model
    onnx_model = OnnxT5(model_or_model_path, sessions)
    print("Done!")

    return onnx_model


def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
    """
    Build an ``OnnxT5`` from model files that were already converted.

    The three paths returned by ``get_model_paths`` must all exist; otherwise
    an ``AssertionError`` is raised whose message depends on ``quantized``.

    Example:
    >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")

    """
    model_paths = get_model_paths(model_name, Path(onnx_models_path), quantized)
    encoder_path, decoder_path, init_decoder_path = model_paths

    # Same existence check either way; only the hint in the message differs.
    if quantized:
        missing_hint = (
            "quantized model don't exist in the model folder, first quantize the model!"
        )
    else:
        missing_hint = (
            "all or some models don't exists in the model folder, first convert the model! "
        )

    required_files = (encoder_path, decoder_path, init_decoder_path)
    assert all(candidate.exists() for candidate in required_files), missing_hint

    sessions = get_onnx_runtime_sessions(required_files)

    return OnnxT5(model_name, sessions)