""" ChatGLM model configuration """
import torch

from collections import OrderedDict
from typing import List, Mapping, Optional, Any

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers import PreTrainedTokenizer, TensorType, is_torch_available

logger = logging.get_logger(__name__)


class ChatGLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~ChatGLMModel`].
    It is used to instantiate a ChatGLM model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used
    to control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 150528):
            Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be
            represented by the `input_ids` passed when calling [`~ChatGLMModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the encoder layers and the pooler layer.
        num_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        inner_hidden_size (`int`, *optional*, defaults to 16384):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        max_sequence_length (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether the model should return the last key/values attentions (not used by all models).
    Example:

    ```python
    >>> from configuration_chatglm import ChatGLMConfig
    >>> from modeling_chatglm import ChatGLMModel

    >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
    >>> configuration = ChatGLMConfig()

    >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
    >>> model = ChatGLMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
"""
    model_type = "chatglm"

    def __init__(
            self,
            vocab_size=150528,
            hidden_size=4096,
            num_layers=28,
            num_attention_heads=32,
            layernorm_epsilon=1e-5,
            use_cache=False,
            bos_token_id=150004,
            eos_token_id=150005,
            mask_token_id=150000,
            gmask_token_id=150001,
            pad_token_id=0,
            max_sequence_length=2048,
            inner_hidden_size=16384,
            position_encoding_2d=True,
            quantization_bit=0,
            pre_seq_len=None,
            prefix_projection=False,
            **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.max_sequence_length = max_sequence_length
        self.layernorm_epsilon = layernorm_epsilon
        self.inner_hidden_size = inner_hidden_size
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.gmask_token_id = gmask_token_id
        self.position_encoding_2d = position_encoding_2d
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )


class ChatGLMOnnxConfig(OnnxConfigWithPast):
    r"""
    This class is the custom configuration of a ChatGLMModel needed for exporting the model to ONNX.
    Currently this requires patching parts of the model structure in modeling_chatglm.py beforehand.

    There is also still a TODO list for the current ChatGLMOnnxConfig:
    1. add support for batch_size > 1
    2. add support for use_past

    In modeling_chatglm.py, the attention_fn function needs several `view` calls replaced with
    shape-agnostic tensor operations, because reshape parameters may otherwise get frozen into
    constants in the ONNX graph. Here is the patched code:
    ```python
    >>> def attention_fn(
    >>>         self,
    >>>         query_layer,
    >>>         key_layer,
    >>>         value_layer,
    >>>         attention_mask,
    >>>         hidden_size_per_partition,
    >>>         layer_id,
    >>>         layer_past=None,
    >>>         scaling_attention_score=True,
    >>>         use_cache=False,
    >>> ):
    >>>     if layer_past is not None:
    >>>         past_key, past_value = layer_past[0], layer_past[1]
    >>>         key_layer = torch.cat((past_key, key_layer), dim=0)
    >>>         value_layer = torch.cat((past_value, value_layer), dim=0)
    >>>
    >>>     # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    >>>     seq_len, b, nh, hidden_size = key_layer.shape
    >>>
    >>>     if use_cache:
    >>>         present = (key_layer, value_layer)
    >>>     else:
    >>>         present = None
    >>>
    >>>     query_key_layer_scaling_coeff = float(layer_id + 1)
    >>>     if scaling_attention_score:
    >>>         query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
    >>>
    >>>     # ===================================
    >>>     # Raw attention scores. [b, np, s, s]
    >>>     # ===================================
    >>>
    >>>     # [b, np, sq, sk]
    >>>     # # output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
    >>>
    >>>     # [sq, b, np, hn] -> [sq, b * np, hn]
    >>>     # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    >>>     query_layer = query_layer.flatten(start_dim=1, end_dim=2)
    >>>
    >>>     # [sk, b, np, hn] -> [sk, b * np, hn]
    >>>     # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
    >>>     key_layer = key_layer.flatten(start_dim=1, end_dim=2)
    >>>
    >>>     matmul_result = torch.zeros(
    >>>         1, 1, 1,
    >>>         dtype=query_layer.dtype,
    >>>         device=query_layer.device,
    >>>     )
    >>>
    >>>     matmul_result = torch.baddbmm(
    >>>         matmul_result,
    >>>         query_layer.transpose(0, 1),  # [b * np, sq, hn]
    >>>         key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
    >>>         beta=0.0,
    >>>         alpha=1.0,
    >>>     )
    >>>
    >>>     # [b * np, sq, sk] -> [b, np, sq, sk]
    >>>     # attention_scores = matmul_result.view(*output_size)
    >>>     attention_scores = matmul_result.unsqueeze(0)
    >>>
    >>>     if self.scale_mask_softmax:
    >>>         self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
    >>>         attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
    >>>     else:
    >>>         # if not (attention_mask == 0).all():
    >>>         #     # if auto-regressive, skip
    >>>         attention_scores.masked_fill_(attention_mask, -10000.0)
    >>>         dtype = attention_scores.dtype
    >>>         attention_scores = attention_scores.float()
    >>>         attention_scores = attention_scores * query_key_layer_scaling_coeff
    >>>
    >>>         attention_probs = F.softmax(attention_scores, dim=-1)
    >>>
    >>>         attention_probs = attention_probs.type(dtype)
    >>>
    >>>     # =========================
    >>>     # Context layer. [sq, b, hp]
    >>>     # =========================
    >>>
    >>>     # value_layer -> context layer.
    >>>     # [sk, b, np, hn] --> [b, np, sq, hn]
    >>>
    >>>     # context layer shape: [b, np, sq, hn]
    >>>     # output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
    >>>
    >>>     # change view [sk, b * np, hn]
    >>>     # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
    >>>     value_layer = value_layer.flatten(start_dim=1, end_dim=2)
    >>>
    >>>     # change view [b * np, sq, sk]
    >>>     # attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
    >>>     attention_probs = attention_probs.flatten(start_dim=0, end_dim=1)
    >>>
    >>>     # matmul: [b * np, sq, hn]
    >>>     context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
    >>>
    >>>     # change view [b, np, sq, hn]
    >>>     # context_layer = context_layer.reshape(b, np, sq, hidden_size)
    >>>     context_layer = context_layer.unsqueeze(0)
    >>>
    >>>     # [b, np, sq, hn] --> [sq, b, np, hn]
    >>>     context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
    >>>
    >>>     # [sq, b, np, hn] --> [sq, b, hp]
    >>>     # new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
    >>>     # context_layer = context_layer.view(*new_context_layer_shape)
    >>>     context_layer = context_layer.flatten(start_dim=2)
    >>>
    >>>     outputs = (context_layer, present, attention_probs)
    >>>
    >>>     return outputs
    ```
    The main change is to avoid using `view` with dynamic sizes.

    After changing modeling_chatglm.py, you can use the following code to export and test the ONNX model:
    ```python
    >>> from pathlib import Path
    >>> from transformers import AutoTokenizer, AutoModel
    >>> from transformers.onnx import export, validate_model_outputs
    >>>
    >>> # load model
    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    >>> pt_model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    >>> pt_model = pt_model.float()  # only tested in CPU for now
    >>> pt_model.eval()
    >>> # define path for saving onnx model
    >>> onnx_path = Path(f"model/chatglm-6b.onnx")
    >>> onnx_path.parent.mkdir(exist_ok=True)
    >>> # convert model to onnx
    >>> onnx_config_chatglm = ChatGLMOnnxConfig(pt_model.config, task="causal-lm")
    >>> onnx_inputs, onnx_outputs = export(tokenizer, pt_model,
    >>>                                    onnx_config_chatglm, onnx_config_chatglm.default_onnx_opset,
    >>>                                    onnx_path)
    >>> # test onnx model
    >>> validate_model_outputs(onnx_config_chatglm, tokenizer, pt_model, onnx_path, onnx_outputs, atol=1e-4)
    ```
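
    After the export succeeds, the ONNX graph can be exercised directly with onnxruntime.
    The snippet below is only an illustrative sketch: it assumes the `onnxruntime` package is
    installed and reuses this config's dummy-input helpers so that the feed names and dtypes
    match the exported graph.
    ```python
    >>> import onnxruntime as ort
    >>> from transformers import TensorType
    >>>
    >>> # build feeds with the expected names (input_ids, position_ids, attention_mask)
    >>> dummy_inputs = onnx_config_chatglm.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    >>> session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    >>> ort_outputs = session.run(None, {name: tensor.numpy() for name, tensor in dummy_inputs.items()})
    ```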
    """
    # TODO support dynamic batch size
    default_fixed_batch = 1

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # TODO support use_past
            # self.fill_with_past_key_values_(common_inputs, direction="inputs")
            # common_inputs["attention_mask"] = \
            #     {0: "batch", 1: "past_sequence + sequence", 2: "past_sequence + sequence"}
            raise NotImplementedError('position_ids do not support past_key_values yet.')
        else:
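            # With the 2D positional encoding, position_ids has shape [batch, 2, sequence]
            # and attention_mask (built by get_masks) has shape [batch, 1, sequence, sequence],
            # hence the dynamic "sequence" axes at index 2 and indices 2/3 below.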
            # keep the inputs in the order expected by forward()
            common_inputs["position_ids"] = {0: "batch", 2: "sequence"}
            common_inputs["attention_mask"] = {0: "batch", 2: "sequence", 3: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.num_layers

    @property
    def num_attention_heads(self) -> int:
        return self._config.num_attention_heads

    def get_masks(self, input_ids, device=None):
        """
        Reference: adapted from modeling_chatglm.get_masks.
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self._config.bos_token_id) for seq in input_ids]
        if device:
            attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        else:
            attention_mask = torch.ones((batch_size, seq_length, seq_length), device=input_ids.device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        # print("attention_mask", attention_mask.shape)
        return attention_mask

    def get_position_ids(self, input_ids, mask_positions, device=None, use_gmasks=None):
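        """
        Reference: adapted from modeling_chatglm.get_position_ids.

        With 2D positional encoding the result has shape [batch, 2, seq_length]: channel 0
        holds absolute positions (set to the mask position after the context), channel 1
        holds block positions counting from 1 after the context. Otherwise the result is a
        plain [batch, seq_length] position range, with the same mask-position substitution
        applied when gMASK is not used.
        """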
        batch_size, seq_length = input_ids.shape
        if device is None:
            device = input_ids.device
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [seq.tolist().index(self._config.bos_token_id) for seq in input_ids]
        if self._config.position_encoding_2d:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            block_position_ids = [torch.cat((
                torch.zeros(context_length, dtype=torch.long, device=device),
                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
            )) for context_length in context_lengths]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[i, context_length:] = mask_positions[i]

        # print("position_ids", position_ids.shape)
        return position_ids

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = default_fixed_batch,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
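        # call the plain OnnxConfig implementation (skipping OnnxConfigWithPast), since the
        # past-key-values path is not supported yet; the dummy batch is forced to the fixed size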
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=self.default_fixed_batch, seq_length=seq_length, is_pair=is_pair, framework=framework
        )
        # warn if a non-fixed batch size was requested
        if batch_size != self.default_fixed_batch:
            logger.warning('Dynamic batch size is not supported yet; forcing the fixed batch size of %d.'
                           % self.default_fixed_batch)

        # We need to order the inputs in the way they appear in forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                # TODO support use_past
                # import torch
                #
                # batch, seqlen = common_inputs["input_ids"].shape
                # # Not using the same length for past_key_values
                # past_key_values_length = seqlen + 2
                # past_shape = (
                #     batch,
                #     self.num_attention_heads,
                #     past_key_values_length,
                #     self._config.hidden_size // self.num_attention_heads,
                # )
                # ordered_inputs["past_key_values"] = [
                #     (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
                # ]
                raise NotImplementedError('position_ids do not support past_key_values yet.')

        # Need to add the attention_mask manually
        # 1. add attention_mask
        ordered_inputs["attention_mask"] = self.get_masks(common_inputs["input_ids"])
        # 2. add position_ids
        MASK, gMASK = self._config.mask_token_id, self._config.gmask_token_id
        seqs = common_inputs["input_ids"].tolist()
        mask_positions, use_gmasks = [], []
        for seq in seqs:
            mask_token = gMASK if gMASK in seq else MASK
            use_gmask = mask_token == gMASK
            mask_positions.append(seq.index(mask_token))
            use_gmasks.append(use_gmask)
        ordered_inputs["position_ids"] = self.get_position_ids(common_inputs["input_ids"],
                                                               mask_positions, use_gmasks=use_gmasks)

        if self.use_past:
            # mask_dtype = ordered_inputs["attention_mask"].dtype
            # ordered_inputs["attention_mask"] = torch.cat(
            #     [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            # )
            raise NotImplementedError('position_ids do not support past_key_values yet.')

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13