# Source: YourMT3/amt/src/model/perceiver_helper.py (uploaded by mimbres, commit a03c9b4, 16.4 kB)
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.utils import ModelOutput
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
# from transformers.models.perceiver.modeling_perceiver import (PerceiverAbstractPositionEncoding,
# PerceiverTrainablePositionEncoding,
# PerceiverFourierPositionEncoding)
class PerceiverTFConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PerceiverTF`] model. It is used to instantiate
    a PerceiverTF model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Keyword arguments not listed below are forwarded
    unchanged to [`PretrainedConfig`] through `**kwargs`.

    Args:
        num_latents (`int`, *optional*, defaults to 24):
            The number of latents.
        d_latents (`int`, *optional*, defaults to 128):
            Dimension of the latent embeddings.
        d_model (`int`, *optional*, defaults to 128):
            Dimension of the inputs.
        kv_dim (`int`, *optional*, defaults to 128):
            Stored as `config.kv_dim`. Presumably the dimension of the key/value inputs of the cross-attention
            (TODO: confirm against the encoder implementation).
        num_blocks (`int`, *optional*, defaults to 3):
            Number of blocks in the Transformer encoder.
        num_self_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each self-attention layer in the Transformer encoder.
        num_cross_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each cross-attention layer in the Transformer encoder.
        num_local_transformers_per_block (`int`, *optional*, defaults to 2):
            Number of local Transformer layers per Transformer block in the Transformer encoder.
        num_temporal_transformers_per_block (`int`, *optional*, defaults to 2):
            Number of temporal Transformer layers per Transformer block in the Transformer encoder.
        shared_parallel_temporal_transformers (`bool`, *optional*):
            Whether to share the parameters across the K parallel temporal Transformers in each block. This is not an
            explicit argument of this constructor; if given, it is forwarded through `**kwargs`.
        qk_channels (`int`, *optional*, defaults to 128):
            Dimension to project the queries + keys to before applying attention in the cross-attention and
            self-attention layers of the encoder.
        v_channels (`int`, *optional*, defaults to 128):
            Dimension to project the values to before applying attention in the cross-attention and self-attention
            layers of the encoder.
        cross_attention_shape_for_attention (`str`, *optional*, defaults to `'q'`):
            ** DEPRECATED ** Dimension to use when downsampling the queries and keys in the cross-attention layer of
            the encoder.
        self_attention_widening_factor (`int`):
            ** DEPRECATED ** No longer an explicit argument (see `ff_widening_factor`). Was the widening factor of the
            feed-forward layer in the self-attention layers of the Transformer encoder.
        cross_attention_widening_factor (`int`):
            ** DEPRECATED ** No longer an explicit argument (see `ff_widening_factor`). Was the widening factor of the
            feed-forward layer in the cross-attention layer of the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The dropout probability.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the (non-truncated) normal distribution used to initialize weight matrices,
            latents, embeddings and trainable position encodings.
        layer_norm_type (`str`, *optional*, defaults to `'layer_norm'`):
            The type of layer normalization to use. Can be one of {'layer_norm', 'rms_norm'}.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        sca_use_query_residual (`bool`, *optional*, defaults to `True`):
            Whether to add a query residual in the spectral cross attention (SCA) layer of the encoder.
        use_query_residual (`bool`, *optional*, defaults to `True`):
            Whether to add a query residual in the cross-attention layer of the encoder.
        position_encoding_type (`str`, *optional*, defaults to `'trainable'`):
            Type of position encoding to use. Can be one of {'trainable', 'alibi', 'alibit', 'rope', None}.
        num_max_positions (`int`, *optional*, defaults to 330):
            Maximum number of positions to use for the position encoding.
        vocab_size (`int`, *optional*, defaults to 1391):
            Vocabulary size for the masked language modeling model.
        attention_to_channel (`bool`, *optional*, defaults to `False`):
            Whether SCA should attend to the channel dimension. If False, attention to frequency bin dimension.
        ff_layer_type (`str`, *optional*, defaults to `'mlp'`):
            Type of feed-forward layer to use. Can be one of {'mlp', 'moe'}.
        ff_widening_factor (`int`, *optional*, defaults to 1):
            Widening factor for the feed-forward layers in the MLP/MoE.
        moe_num_experts (`int`, *optional*, defaults to 4):
            Number of experts to use in the mixture of experts (MoE) feed-forward layer.
            Only used if `ff_layer_type` is set to `'moe'`.
        moe_topk (`int`, *optional*, defaults to 2):
            Number of top experts to use in the mixture of experts (MoE) feed-forward layer.
            Only used if `ff_layer_type` is set to `'moe'`.
        rope_type_sca (`str`, *optional*, defaults to `'pixel'`):
            Can be one of {'l'|'lang', 'p'|'pixel', None}. RoPE index type for SCA.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_type_latent (`str`, *optional*, defaults to `'pixel'`):
            Can be one of {'l'|'lang', 'p'|'pixel', None}. RoPE index type for Latent Transformer.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_type_temporal (`str`, *optional*, defaults to `'lang'`):
            Can be one of {'l'|'lang', 'p'|'pixel', None}. RoPE index type for Temporal Transformer.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_apply_to_keys (`bool`, *optional*, defaults to `False`):
            Whether to apply RoPE to the keys in the self/cross-attention layers.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_partial_pe (`bool`, *optional*, defaults to `False`):
            Whether to use partial RoPE in the self/cross-attention.
            Only used if `position_encoding_type` is set to `'rope'`.
        rope_trainable (`bool`, *optional*, defaults to `False`):
            Whether to make the RoPE trainable.
            Only used if `position_encoding_type` is set to `'rope'`.

    Example:

    ```python
    >>> from model.perceiver_helper import PerceiverTFConfig
    >>> from model.perceiver_mod import PerceiverTFEncoder

    >>> # Initializing a configuration with the default values
    >>> configuration = PerceiverTFConfig()

    >>> # Initializing a model from that configuration
    >>> model = PerceiverTFEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "perceivertf"

    def __init__(
        self,
        num_latents=24,
        d_latents=128,
        d_model=128,
        kv_dim=128,
        num_blocks=3,
        num_self_attention_heads=8,
        num_cross_attention_heads=8,
        num_local_transformers_per_block=2,
        num_temporal_transformers_per_block=2,
        qk_channels=128,
        v_channels=128,
        cross_attention_shape_for_attention="q",
        # self_attention_widening_factor=1, ** DEPRECATED **
        # cross_attention_widening_factor=1, ** DEPRECATED **
        hidden_act="gelu",
        dropout_rate=0.1,
        initializer_range=0.02,
        layer_norm_type="layer_norm",
        layer_norm_eps=1e-5,
        sca_use_query_residual=True,
        use_query_residual=True,
        position_encoding_type="trainable",
        num_max_positions=330,
        vocab_size=1391,
        attention_to_channel=False,
        ff_layer_type="mlp",
        ff_widening_factor=1,
        moe_num_experts=4,
        moe_topk=2,
        rope_type_sca="pixel",
        rope_type_latent="pixel",
        rope_type_temporal="lang",
        rope_apply_to_keys=False,
        rope_partial_pe=False,
        rope_trainable=False,
        **kwargs,
    ):
        # Unknown keyword arguments (including deprecated ones) go to PretrainedConfig.
        super().__init__(**kwargs)
        self.num_latents = num_latents
        self.d_latents = d_latents
        self.d_model = d_model
        self.kv_dim = kv_dim
        self.qk_channels = qk_channels
        self.v_channels = v_channels
        self.num_blocks = num_blocks
        self.num_self_attention_heads = num_self_attention_heads
        self.num_cross_attention_heads = num_cross_attention_heads
        self.num_local_transformers_per_block = num_local_transformers_per_block
        self.num_temporal_transformers_per_block = num_temporal_transformers_per_block
        self.sca_use_query_residual = sca_use_query_residual
        self.use_query_residual = use_query_residual
        self.position_encoding_type = position_encoding_type
        self.num_max_positions = num_max_positions
        # self.self_attention_widening_factor = self_attention_widening_factor
        # self.cross_attention_widening_factor = cross_attention_widening_factor
        self.cross_attention_shape_for_attention = cross_attention_shape_for_attention
        self.attention_to_channel = attention_to_channel
        self.ff_layer_type = ff_layer_type
        self.ff_widening_factor = ff_widening_factor
        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        self.rope_type_sca = rope_type_sca
        self.rope_type_latent = rope_type_latent
        self.rope_type_temporal = rope_type_temporal
        self.rope_apply_to_keys = rope_apply_to_keys
        self.rope_partial_pe = rope_partial_pe
        self.rope_trainable = rope_trainable
        self.hidden_act = hidden_act
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.layer_norm_type = layer_norm_type
        self.layer_norm_eps = layer_norm_eps
        # masked language modeling attributes
        self.vocab_size = vocab_size
class PerceiverTFPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PerceiverTFConfig
    base_model_prefix = "perceivertf"
    main_input_name = "inputs"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Weight matrices, latents, trainable position embeddings, `nn.ParameterDict` entries and embeddings are drawn
        from a normal distribution with std `config.initializer_range`. Biases are zeroed. LayerNorm is reset to the
        identity affine transform. ALiBi-style modules (those exposing a `slopes` parameter) are reset via their own
        `reset_parameters()`.

        Args:
            module (`nn.Module`): The submodule to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif hasattr(module, "latents"):
            module.latents.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, "_pos_emb") and isinstance(module._pos_emb, nn.Parameter):
            # initialize PerceiverTFTrainablePE
            # NOTE(review): because this is an elif chain, a module that has both `_pos_emb` and
            # `_pos_emb_temporal` only gets `_pos_emb` initialized here — confirm that is intended.
            module._pos_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, "_pos_emb_temporal"):
            # initialize PerceiverTFTrainablePE
            module._pos_emb_temporal.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, "slopes") and isinstance(module.slopes, nn.Parameter):
            # initialize AlibiPositionalBias
            module.reset_parameters()
        elif isinstance(module, nn.ParameterDict):
            for modality in module.keys():
                module[modality].data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # With `elementwise_affine=False` (or `bias=False` in newer PyTorch) these are None,
            # so guard before touching them instead of raising AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        # elif hasattr(module, "position_embeddings") and isinstance(
        #         module, PerceiverTrainablePositionEncoding):
        #     module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range)
# Use 'MoEModelOutputWithCrossAttentions' in place of 'ModelOutputWithCrossAttentions' for MoE models; it adds a field for router outputs.
@dataclass
class MoEModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions,
    plus router outputs for Mixture of Experts (MoE) models.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the cross-attention layers, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        router_logits (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each MoE layer), presumably of shape
            `(batch_size, sequence_length, num_experts)`.

            Router outputs computed by the MoE routers; these terms are used to compute the auxiliary load-balancing
            loss and the z_loss for Mixture of Experts models.
            NOTE(review): this field was previously documented as `router_probs` ("raw router probabilities");
            confirm against the MoE layer whether the stored values are pre- or post-softmax.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    router_logits: Optional[Tuple[torch.FloatTensor]] = None