# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.utils import ModelOutput
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
# from transformers.models.perceiver.modeling_perceiver import (PerceiverAbstractPositionEncoding,
#                                                               PerceiverTrainablePositionEncoding,
#                                                               PerceiverFourierPositionEncoding)


class PerceiverTFConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PerceiverTF`]. It is used to instantiate a
    PerceiverTF model according to the specified arguments, defining the model architecture. The class is written in
    the style of the Hugging Face Perceiver configuration
    ([deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver)); the default values listed
    below are the ones in the `__init__` signature of this class.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_latents (`int`, *optional*, defaults to 24):
            The number of latents.
        d_latents (`int`, *optional*, defaults to 128):
            Dimension of the latent embeddings.
        d_model (`int`, *optional*, defaults to 128):
            Dimension of the inputs.
        kv_dim (`int`, *optional*, defaults to 128):
            Presumably the dimension of the key/value inputs of the cross-attention layers; this argument was
            previously undocumented -- TODO confirm against the encoder implementation.
        num_blocks (`int`, *optional*, defaults to 3):
            Number of blocks in the Transformer encoder.
        num_self_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each self-attention layer in the Transformer encoder.
        num_cross_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each cross-attention layer in the Transformer encoder.
        num_local_transformers_per_block (`int`, *optional*, defaults to 2):
            Number of local Transformer layers per Transformer block in the Transformer encoder.
        num_temporal_transformers_per_block (`int`, *optional*, defaults to 2):
            Number of temporal Transformer layers per Transformer block in the Transformer encoder.
        shared_parallel_temporal_transformers (`bool`, *optional*):
            Whether to share the parameters across the K parallel temporal Transformers in each block.
            NOTE: this is not a named argument of `__init__`; if given, it is forwarded to [`PretrainedConfig`]
            through `**kwargs`.
        qk_channels (`int`, *optional*, defaults to 128):
            Dimension to project the queries + keys before applying attention in the cross-attention and self-attention
            layers of the encoder.
        v_channels (`int`, *optional*, defaults to 128):
            Dimension to project the values before applying attention in the cross-attention and self-attention layers
            of the encoder.
        cross_attention_shape_for_attention (`str`, *optional*, defaults to `'q'`):
            ** DEPRECATED ** Dimension to use when downsampling the queries and keys in the cross-attention layer of
            the encoder. The argument is still accepted and stored on the configuration.
        self_attention_widening_factor, cross_attention_widening_factor:
            ** DEPRECATED ** Widening factors of the feed-forward layers. Both are commented out of the `__init__`
            signature and are no longer named arguments.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the normal distribution used for initializing weight matrices.
        layer_norm_type (`str`, *optional*, defaults to `'layer_norm'`):
            The type of layer normalization to use. Can be one of {'layer_norm', 'rms_norm'}.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        sca_use_query_residual (`bool`, *optional*, defaults to `True`):
            Whether to add a query residual in the spectral cross attention (SCA) layer of the encoder.
        use_query_residual (`bool`, *optional*, defaults to `True`):
            Whether to add a query residual in the cross-attention layer of the encoder.
        position_encoding_type (`str`, *optional*, defaults to `'trainable'`):
            Type of position encoding to use. Can be one of {'trainable', 'alibi', 'alibit', 'rope', None}.
        num_max_positions (`int`, *optional*, defaults to 330):
            Maximum number of positions to use for the position encoding.
        vocab_size (`int`, *optional*, defaults to 1391):
            Vocabulary size for the masked language modeling model.
        attention_to_channel (`bool`, *optional*, defaults to `False`):
            Whether SCA should attend to the channel dimension. If False, attention to frequency bin dimension.
        ff_layer_type (`str`, *optional*, defaults to `'mlp'`):
            Type of feed-forward layer to use. Can be one of {'mlp', 'moe'}.
        ff_widening_factor (`int`, *optional*, defaults to 1):
            Widening factor for the feed-forward layers in the MLP/MoE.
        moe_num_experts (`int`, *optional*, defaults to 4):
            Number of experts to use in the mixture of experts (MoE) feed-forward layer.
            Only used if `ff_layer_type` is set to `'moe'`.
        moe_topk (`int`, *optional*, defaults to 2):
            Number of top experts to use in the mixture of experts (MoE) feed-forward layer.
            Only used if `ff_layer_type` is set to `'moe'`.
        rope_type_sca (`str`, *optional*, defaults to `'pixel'`):
            RoPE index type for SCA. Can be one of {'l'|'lang', 'p'|'pixel', None}.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_type_latent (`str`, *optional*, defaults to `'pixel'`):
            RoPE index type for Latent Transformer. Can be one of {'l'|'lang', 'p'|'pixel', None}.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_type_temporal (`str`, *optional*, defaults to `'lang'`):
            RoPE index type for Temporal Transformer. Can be one of {'l'|'lang', 'p'|'pixel', None}.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_apply_to_keys (`bool`, *optional*, defaults to `False`):
            Whether to apply RoPE to the keys in the self/cross-attention layers.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_partial_pe (`bool`, *optional*, defaults to `False`):
            Whether to use partial RoPE in the self/cross-attention.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_trainable (`bool`, *optional*, defaults to `False`):
            Whether to make the RoPE trainable.
            Only used if `position_encoding_type` is set to `'rope'`.
        kwargs:
            Any remaining keyword arguments are forwarded unchanged to [`PretrainedConfig`].

    Example:

    ```python
    >>> from model.perceiver_mod import PerceiverTFEncoder, PerceiverTFConfig

    >>> # Initializing a PerceiverTF configuration with the default arguments
    >>> configuration = PerceiverTFConfig()

    >>> # Initializing a model from that configuration
    >>> model = PerceiverTFEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "perceivertf"

    def __init__(
        self,
        num_latents=24,
        d_latents=128,
        d_model=128,
        kv_dim=128,
        num_blocks=3,
        num_self_attention_heads=8,
        num_cross_attention_heads=8,
        num_local_transformers_per_block=2,
        num_temporal_transformers_per_block=2,
        qk_channels=128,
        v_channels=128,
        cross_attention_shape_for_attention="q",
        # self_attention_widening_factor=1, ** DEPRECATED **
        # cross_attention_widening_factor=1, ** DEPRECATED **
        hidden_act="gelu",
        dropout_rate=0.1,
        initializer_range=0.02,
        layer_norm_type="layer_norm",
        layer_norm_eps=1e-5,
        sca_use_query_residual=True,
        use_query_residual=True,
        position_encoding_type="trainable",
        num_max_positions=330,
        vocab_size=1391,
        attention_to_channel=False,
        ff_layer_type="mlp",
        ff_widening_factor=1,
        moe_num_experts=4,
        moe_topk=2,
        rope_type_sca="pixel",
        rope_type_latent="pixel",
        rope_type_temporal="lang",
        rope_apply_to_keys=False,
        rope_partial_pe=False,
        rope_trainable=False,
        **kwargs,
    ):
        # Unrecognised keyword arguments are handled by the PretrainedConfig base class.
        super().__init__(**kwargs)

        # Latent array / input / attention projection dimensions.
        self.num_latents = num_latents
        self.d_latents = d_latents
        self.d_model = d_model
        self.kv_dim = kv_dim
        self.qk_channels = qk_channels
        self.v_channels = v_channels

        # Encoder depth, attention heads, residuals and position encoding.
        self.num_blocks = num_blocks
        self.num_self_attention_heads = num_self_attention_heads
        self.num_cross_attention_heads = num_cross_attention_heads
        self.num_local_transformers_per_block = num_local_transformers_per_block
        self.num_temporal_transformers_per_block = num_temporal_transformers_per_block
        self.sca_use_query_residual = sca_use_query_residual
        self.use_query_residual = use_query_residual
        self.position_encoding_type = position_encoding_type
        self.num_max_positions = num_max_positions
        # Deprecated widening factors: kept only as a reminder, no longer stored.
        # self.self_attention_widening_factor = self_attention_widening_factor
        # self.cross_attention_widening_factor = cross_attention_widening_factor
        self.cross_attention_shape_for_attention = cross_attention_shape_for_attention
        self.attention_to_channel = attention_to_channel
        # Feed-forward (MLP / MoE) settings.
        self.ff_layer_type = ff_layer_type
        self.ff_widening_factor = ff_widening_factor
        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        # RoPE settings.
        self.rope_type_sca = rope_type_sca
        self.rope_type_latent = rope_type_latent
        self.rope_type_temporal = rope_type_temporal
        self.rope_apply_to_keys = rope_apply_to_keys
        self.rope_partial_pe = rope_partial_pe
        self.rope_trainable = rope_trainable

        # Activation, regularisation, initialisation and normalisation settings.
        self.hidden_act = hidden_act
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.layer_norm_type = layer_norm_type
        self.layer_norm_eps = layer_norm_eps

        # masked language modeling attributes
        self.vocab_size = vocab_size


class PerceiverTFPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PerceiverTFConfig
    base_model_prefix = "perceivertf"
    main_input_name = "inputs"

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place.

        The first matching rule below is applied and all later ones are skipped:

        * `nn.Linear` / `nn.Conv2d`: weight ~ N(0, `config.initializer_range`), bias (if any) set to zero.
        * modules with a `latents` attribute: `latents` ~ N(0, `config.initializer_range`).
        * modules whose `_pos_emb` is an `nn.Parameter`: `_pos_emb` ~ N(0, `config.initializer_range`).
        * modules whose `_pos_emb_temporal` is a tensor: same normal initialization.
        * modules whose `slopes` is an `nn.Parameter`: delegated to `module.reset_parameters()`.
        * `nn.ParameterDict`: every entry ~ N(0, `config.initializer_range`).
        * `nn.Embedding`: weight ~ N(0, `config.initializer_range`), padding row (if any) set to zero.
        * `nn.LayerNorm`: bias set to zero and weight set to one, when those affine parameters exist.

        Any other module is left untouched.

        Args:
            module: The sub-module whose parameters should be initialized.
        """
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif hasattr(module, "latents"):
            module.latents.data.normal_(mean=0.0, std=std)
        elif hasattr(module, "_pos_emb") and isinstance(module._pos_emb, nn.Parameter):
            # initialize PerceiverTFTrainablePE
            # NOTE(review): because the branches are mutually exclusive, a module exposing both
            # `_pos_emb` and `_pos_emb_temporal` only gets `_pos_emb` initialized here -- confirm
            # that this is intended.
            module._pos_emb.data.normal_(mean=0.0, std=std)
        elif hasattr(module, "_pos_emb_temporal") and isinstance(module._pos_emb_temporal, torch.Tensor):
            # initialize PerceiverTFTrainablePE
            # The tensor check mirrors the guards on `_pos_emb` / `slopes`: a `None` (or otherwise
            # non-tensor) placeholder is skipped instead of raising AttributeError on `.data`.
            module._pos_emb_temporal.data.normal_(mean=0.0, std=std)
        elif hasattr(module, "slopes") and isinstance(module.slopes, nn.Parameter):
            # initialize AlibiPositionalBias
            module.reset_parameters()
        elif isinstance(module, nn.ParameterDict):
            for modality in module.keys():
                module[modality].data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # `weight` / `bias` are None when the LayerNorm was built without affine parameters
            # (`elementwise_affine=False`, or `bias=False` on recent PyTorch), so guard both.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        # elif hasattr(module, "position_embeddings") and isinstance(
        #         module, PerceiverTrainablePositionEncoding):
        #     module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range)


# Replace the 'ModelOutputWithCrossAttentions' with 'MoEModelOutputWithCrossAttentions' for MoE
@dataclass
class MoEModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    Plus, `router_logits` for Mixture of Experts models.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        router_logits (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) holding the outputs of the MoE routers; these terms are
            used to compute the auxiliary loss and the z_loss for Mixture of Experts models.

            NOTE(review): this field was previously documented under the name `router_probs`, as raw router
            probabilities of shape `(batch_size, sequence_length, num_experts)`. The attribute is actually named
            `router_logits` -- confirm with the code that fills it whether it stores logits or probabilities, and
            which flag controls whether it is returned.
    """

    # All fields default to None so that a model can populate only the outputs it was asked for.
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None