Spaces:
Running
Running
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
"""perceiver_mod.py | |
Implementation of the PerceiverTF encoder with: | |
- AliBi positional bias | |
- Mixtral of Experts (MoE) feedforward layer | |
""" | |
import math | |
from einops import rearrange | |
from typing import Optional, Tuple, Union, List, Dict, Literal | |
import torch | |
from torch import nn | |
from transformers.models.perceiver.modeling_perceiver import PerceiverSelfOutput | |
from transformers.pytorch_utils import (apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer) | |
from model.perceiver_helper import MoEModelOutputWithCrossAttentions | |
from model.perceiver_helper import PerceiverTFPreTrainedModel, PerceiverTFConfig | |
from model.positional_encoding import AlibiPositionalBias, get_rotary_emb | |
from model.ops import get_layer_norm | |
from model.ff_layer import get_ff_layer | |
class PerceiverEmbeddings(nn.Module): | |
"""Construct the latent embeddings sharable with token embeddings in the decoder.""" | |
def __init__(self, config, shared_emb: Optional[nn.Parameter] = None): | |
super().__init__() | |
if shared_emb is not None: | |
self.latents = shared_emb | |
assert self.latents.shape == (config.num_latents, config.d_latents) | |
else: | |
self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) | |
def forward(self, batch_size: int): | |
return self.latents.expand(batch_size, -1, -1) | |
class PerceiverTFTrainablePE(nn.Module): | |
"""Construct the trainable absolute positional embeddings.""" | |
def __init__(self, position_encoding_type: Literal['trainable', 'tkd', 'td', 'tk', 'kdt'], max_t: int, k: int, | |
d: int) -> None: | |
super().__init__() | |
self.position_encoding_type = position_encoding_type | |
self.max_t = max_t | |
self.k = k | |
self.d = d | |
if position_encoding_type in ['trainable', 'tkd']: | |
self._pos_emb = nn.Parameter(torch.randn(max_t, k, d)) | |
elif position_encoding_type == 'td': | |
self._pos_emb = nn.Parameter(torch.randn(max_t, d)) | |
elif position_encoding_type == 'tk': | |
self._pos_emb = nn.Parameter(torch.randn(max_t, k)) | |
elif position_encoding_type == 'kdt': | |
self._pos_emb = nn.Parameter(torch.randn(k, d)) | |
self._pos_emb_temporal = nn.Parameter(torch.randn(max_t, d)) | |
else: | |
raise ValueError(f'unknown position encoding type {position_encoding_type}') | |
def forward(self): | |
pos_emb_temporal = None | |
if self.position_encoding_type in ['trainable', 'tkd']: | |
pos_emb = self._pos_emb | |
elif self.position_encoding_type == 'td': | |
pos_emb = self._pos_emb.unsqueeze(1).expand(-1, self.k, -1) | |
elif self.position_encoding_type == 'tk': | |
pos_emb = self._pos_emb.unsqueeze(-1).expand(-1, -1, self.d) | |
elif self.position_encoding_type == 'kdt': | |
pos_emb = self._pos_emb.unsqueeze(0).expand(self.max_t, -1, -1) | |
pos_emb_temporal = self._pos_emb_temporal | |
return pos_emb, pos_emb_temporal | |
class PerceiverAlibiSelfAttention(nn.Module): | |
""" | |
Multi-headed {cross, self}-attention + Alibi/Rotary positional bias/emb: | |
- Can be used both in the encoder as well as in the decoder. | |
- Modified from PerceiverSelfAttention in modeling_perceiver.py to support Alibi positional bias | |
""" | |
def __init__( | |
self, | |
config, | |
is_cross_attention=False, | |
qk_channels=None, | |
v_channels=None, | |
num_heads=1, | |
q_dim=None, | |
kv_dim=None, | |
rotary_emb=None, | |
): | |
super().__init__() | |
self.num_heads = num_heads | |
# Q and K must have the same number of channels. | |
# Default to preserving Q's input's shape. | |
if qk_channels is None: | |
qk_channels = q_dim | |
# V's num_channels determines the shape of the output of QKV-attention. | |
# Default to the same number of channels used in the key-query operation. | |
if v_channels is None: | |
v_channels = qk_channels | |
if qk_channels % num_heads != 0: | |
raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).") | |
if v_channels % num_heads != 0: | |
raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).") | |
self.qk_channels = qk_channels | |
self.v_channels = v_channels | |
self.qk_channels_per_head = self.qk_channels // num_heads | |
self.v_channels_per_head = self.v_channels // num_heads | |
# Layer normalization | |
self.layernorm1 = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps) | |
if is_cross_attention: | |
self.layernorm2 = get_layer_norm(kv_dim, config.layer_norm_type, config.layer_norm_eps) | |
else: | |
self.layernorm2 = nn.Identity() | |
# self.layernorm1 = nn.LayerNorm(q_dim) | |
# self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity() | |
# Projection matrices | |
self.query = nn.Linear(q_dim, qk_channels) | |
self.key = nn.Linear(kv_dim, qk_channels) | |
self.value = nn.Linear(kv_dim, v_channels) | |
self.dropout = nn.Dropout(config.dropout_rate) | |
# (Modified) Alibi positional bias | |
if config.position_encoding_type == 'alibi': | |
self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=False) | |
elif config.position_encoding_type == 'alibit': | |
self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=True) | |
else: | |
self.alibi_bias = None | |
# (Modified) RoPE | |
if config.position_encoding_type == 'rope': | |
assert rotary_emb is not None, "rotary_emb must be provided for RoPE." | |
self.rotary_emb = rotary_emb | |
else: | |
self.rotary_emb = None | |
self.rope_apply_to_keys = config.rope_apply_to_keys # False by default | |
def transpose_for_scores(self, x, channels_per_head): | |
new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head) | |
x = x.view(*new_x_shape) | |
return x.permute(0, 2, 1, 3) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
inputs: Optional[torch.FloatTensor] = None, | |
inputs_mask: Optional[torch.FloatTensor] = None, | |
output_attentions: Optional[bool] = False, | |
) -> Tuple[torch.Tensor]: | |
hidden_states = self.layernorm1(hidden_states) | |
inputs = self.layernorm2(inputs) | |
# Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module, | |
# the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to. | |
is_cross_attention = inputs is not None | |
queries = self.query(hidden_states) | |
if is_cross_attention: | |
keys = self.key(inputs) | |
values = self.value(inputs) | |
attention_mask = inputs_mask | |
else: | |
keys = self.key(hidden_states) | |
values = self.value(hidden_states) | |
# Reshape channels for multi-head attention. | |
# We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head) | |
queries = self.transpose_for_scores(queries, self.qk_channels_per_head) | |
keys = self.transpose_for_scores(keys, self.qk_channels_per_head) | |
values = self.transpose_for_scores(values, self.v_channels_per_head) | |
# (Modified) RoPE | |
if self.rotary_emb is not None: | |
queries = self.rotary_emb.apply_rotary_custom(queries) | |
if self.rope_apply_to_keys is True: | |
keys = self.rotary_emb.apply_rotary_custom(keys) | |
# Take the dot product between the queries and keys to get the raw attention scores. | |
attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) | |
# (Modified) Alibi positional bias | |
if self.alibi_bias is not None: | |
batch_size, num_heads, q_seq_len, k_seq_len = attention_scores.shape | |
attention_scores += self.alibi_bias(q_seq_len, | |
k_seq_len) # auto-broadcasting to (b, num_heads, q_seq_len, k_seq_len) | |
_, _, _, q_head_dim = queries.shape | |
_, _, _, v_head_dim = values.shape | |
hiddens = self.num_heads * v_head_dim | |
attention_scores = attention_scores / math.sqrt(q_head_dim) | |
if attention_mask is not None: | |
# Apply the attention mask (precomputed for all layers in PerceiverModel forward() function) | |
attention_scores = attention_scores + attention_mask | |
# Normalize the attention scores to probabilities. | |
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |
# This is actually dropping out entire tokens to attend to, which might | |
# seem a bit unusual, but is taken from the original Transformer paper. | |
attention_probs = self.dropout(attention_probs) | |
# Mask heads if we want to | |
if head_mask is not None: | |
attention_probs = attention_probs * head_mask | |
context_layer = torch.matmul(attention_probs, values) | |
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |
new_context_layer_shape = context_layer.size()[:-2] + (hiddens,) | |
context_layer = context_layer.view(*new_context_layer_shape) | |
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) | |
return outputs | |
class PerceiverAlibiAttention(nn.Module): | |
""" | |
Attention module, including a dense block + Alibi | |
: modified from PerceiverAttention in modeling_perceiver.py to support Alibi positional bias | |
""" | |
def __init__( | |
self, | |
config, | |
is_cross_attention=False, | |
qk_channels=None, | |
v_channels=None, | |
num_heads=1, | |
q_dim=None, | |
kv_dim=None, | |
use_query_residual=True, | |
rotary_emb=None, | |
): | |
super().__init__() | |
# MultiHead attention | |
if is_cross_attention and qk_channels is None: | |
if config.cross_attention_shape_for_attention == "q": | |
qk_channels = q_dim | |
elif config.cross_attention_shape_for_attention == "kv": | |
qk_channels = kv_dim | |
else: | |
raise ValueError(f"Unknown value {config.cross_attention_shape_for_attention} for " | |
"cross_attention_shape_for_attention.") | |
else: | |
if qk_channels is None: | |
qk_channels = q_dim | |
if v_channels is None: | |
v_channels = qk_channels | |
self.self = PerceiverAlibiSelfAttention(config, | |
is_cross_attention=is_cross_attention, | |
qk_channels=qk_channels, | |
v_channels=v_channels, | |
num_heads=num_heads, | |
q_dim=q_dim, | |
kv_dim=kv_dim, | |
rotary_emb=rotary_emb) | |
# dense block | |
output_channels = None | |
if is_cross_attention: | |
output_channels = q_dim | |
else: | |
if output_channels is None: | |
output_channels = v_channels | |
self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels) | |
self.use_query_residual = use_query_residual | |
self.pruned_heads = set() | |
def prune_heads(self, heads): | |
if len(heads) == 0: | |
return | |
heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, | |
self.self.attention_head_size, self.pruned_heads) | |
# Prune linear layers | |
self.self.query = prune_linear_layer(self.self.query, index) | |
self.self.key = prune_linear_layer(self.self.key, index) | |
self.self.value = prune_linear_layer(self.self.value, index) | |
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) | |
# Update hyper params and store pruned heads | |
self.self.num_attention_heads = self.self.num_attention_heads - len(heads) | |
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads | |
self.pruned_heads = self.pruned_heads.union(heads) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
inputs: Optional[torch.FloatTensor] = None, | |
inputs_mask: Optional[torch.FloatTensor] = None, | |
output_attentions: Optional[bool] = False, | |
) -> Tuple[torch.Tensor]: | |
self_outputs = self.self( | |
hidden_states, | |
attention_mask, | |
head_mask, | |
inputs, | |
inputs_mask, | |
output_attentions, | |
) | |
# Output projection | |
attention_output = self.output(self_outputs[0]) | |
# Optionally include a residual to the original queries. | |
# Consider omitting the residual if the semantics of query and output | |
# are different, e.g. if queries are positions and outputs are pixels. | |
if self.use_query_residual: | |
attention_output = attention_output + hidden_states | |
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them | |
return outputs | |
class PerceiverAlibiLayer(nn.Module): | |
"""Construct a single PerceiverTF layer with: | |
- Alibi positional bias | |
- RoPE | |
- Mixtral of Experts (MoE) feedforward layer | |
""" | |
def __init__( | |
self, | |
config, | |
is_cross_attention=False, | |
qk_channels=None, | |
v_channels=None, | |
num_heads=1, | |
q_dim=None, | |
kv_dim=None, | |
widening_factor=1, | |
use_query_residual=True, | |
rotary_emb=None, | |
): | |
super().__init__() | |
self.chunk_size_feed_forward = config.chunk_size_feed_forward | |
self.seq_len_dim = 1 | |
self.attention = PerceiverAlibiAttention(config, | |
is_cross_attention=is_cross_attention, | |
qk_channels=qk_channels, | |
v_channels=v_channels, | |
num_heads=num_heads, | |
q_dim=q_dim, | |
kv_dim=kv_dim, | |
use_query_residual=use_query_residual, | |
rotary_emb=rotary_emb) | |
self.layernorm = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps) | |
# self.layernorm = nn.LayerNorm(q_dim) | |
self.mlp = get_ff_layer(config, input_size=q_dim, widening_factor=widening_factor) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
inputs: Optional[torch.FloatTensor] = None, | |
inputs_mask: Optional[torch.FloatTensor] = None, | |
output_attentions: Optional[bool] = False, | |
) -> Tuple[torch.Tensor]: | |
attention_outputs = self.attention( | |
hidden_states, | |
attention_mask, | |
head_mask, | |
inputs, | |
inputs_mask, | |
output_attentions, | |
) | |
attention_output = attention_outputs[0] | |
outputs = attention_outputs[1:] # add attentions if we output attention weights | |
"""apply_chunking_to_forward: | |
This function chunks the input_tensors into smaller input tensor parts of size | |
chunk_size over the dimension chunk_dim. It then applies a layer forward_fn to | |
each chunk independently to save memory.If the forward_fn is independent across | |
the chunk_dim this function will yield the same result as not applying it. | |
""" | |
layer_output, router_logits = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, | |
self.seq_len_dim, attention_output) | |
layer_output = layer_output + attention_output # residual connection | |
outputs = (layer_output,) + outputs + (router_logits,) # add router_logits to outputs | |
return outputs | |
def feed_forward_chunk(self, attention_output): | |
layer_output = self.layernorm(attention_output) | |
layer_output, router_logits = self.mlp(layer_output) # router_logits is returned only when using MoE. | |
return layer_output, router_logits | |
class PerceiverTFEncoderBlock(nn.Module): | |
"""Construct a single block of PerceiverTF encoder: | |
- Spectral Cross Attention (SCA) | |
- Local latent transformer layers | |
- Temporal transformer layers | |
- added Alibi positional bias, RoPE, gMLP and MoE feedforward layer | |
""" | |
def __init__(self, | |
config: PerceiverTFConfig, | |
kv_dim: Optional[int] = None, | |
sca_use_query_residual: bool = True, | |
rotary_emb_sca: Optional[nn.Module] = None, | |
rotary_emb_latent: Optional[nn.Module] = None, | |
rotary_emb_temporal: Optional[nn.Module] = None): | |
super().__init__() | |
self.config = config | |
# Check that we can use multihead-attention with these shapes. | |
if config.d_latents % config.num_self_attention_heads != 0: | |
raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by" | |
f" num_self_attend_heads ({config.num_self_attention_heads}).") | |
if config.d_latents % config.num_cross_attention_heads != 0: | |
raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by" | |
f" num_cross_attend_heads ({config.num_cross_attention_heads}).") | |
if kv_dim is None: | |
kv_dim = config.kv_dim | |
if sca_use_query_residual is None: | |
sca_use_query_residual = config.sca_use_query_residual | |
# Spectral Cross Attention (SCA) layer. | |
self.sca_attention_to_channel = config.attention_to_channel | |
self.spectral_cross_attention = PerceiverAlibiAttention(config, | |
is_cross_attention=True, | |
qk_channels=config.qk_channels, | |
v_channels=config.v_channels, | |
num_heads=config.num_cross_attention_heads, | |
q_dim=config.d_latents, | |
kv_dim=kv_dim, | |
use_query_residual=sca_use_query_residual, | |
rotary_emb=rotary_emb_sca) # (Modified) RoPE | |
# Local latent trasformer layers. | |
local_transformer_layers = [] | |
for _ in range(config.num_local_transformers_per_block): | |
layer = PerceiverAlibiLayer( | |
config, | |
is_cross_attention=False, | |
qk_channels=config.qk_channels, # projection dim for q and k. | |
v_channels=config.v_channels, # projection dim for v. | |
num_heads=config.num_self_attention_heads, | |
q_dim=config.d_model, | |
kv_dim=config.d_model, | |
widening_factor=config.ff_widening_factor, | |
use_query_residual=config.use_query_residual, | |
rotary_emb=rotary_emb_latent # (Modified) RoPE | |
) | |
local_transformer_layers.append(layer) | |
self.local_transformer = nn.ModuleList(local_transformer_layers) | |
# Temporal transformer layers. | |
temporal_transformer_layers = [] | |
for _ in range(config.num_temporal_transformers_per_block): | |
layer = PerceiverAlibiLayer( | |
config, | |
is_cross_attention=False, | |
qk_channels=config.qk_channels, # projection dim for q and k. | |
v_channels=config.v_channels, # projection dim for v. | |
num_heads=config.num_self_attention_heads, | |
q_dim=config.d_model, | |
kv_dim=config.d_model, | |
widening_factor=config.ff_widening_factor, | |
use_query_residual=config.use_query_residual, | |
rotary_emb=rotary_emb_temporal # (Modified) RoPE | |
) | |
temporal_transformer_layers.append(layer) | |
self.temporal_transformer = nn.ModuleList(temporal_transformer_layers) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
inputs: Optional[torch.FloatTensor] = None, | |
inputs_mask: Optional[torch.FloatTensor] = None, | |
local_attention_mask: Optional[torch.FloatTensor] = None, | |
temporal_attention_mask: Optional[torch.FloatTensor] = None, | |
local_head_mask: Optional[torch.FloatTensor] = None, | |
temporal_head_mask: Optional[torch.FloatTensor] = None, | |
pos_emb_temporal: Optional[torch.FloatTensor] = None, | |
output_attentions: Optional[bool] = False, | |
output_hidden_states: Optional[bool] = False, | |
output_router_logits: Optional[bool] = False, # Only used for MoE. | |
return_dict: Optional[bool] = True, | |
) -> Union[Tuple, MoEModelOutputWithCrossAttentions]: | |
""" | |
Inputs: | |
hidden_states: (B, T, K, D) | |
inputs: (B, T, F, C) | |
Returns: | |
hidden_states: (B, T, K, D) | |
Args: | |
hidden_states: | |
latent_array (B, T, num_latents, d_latents) for SCA. The latent array | |
with shape (B, K, D) is expanded by t, and positional embeddings are | |
added to it. | |
inputs: torch.FloatTensor | |
The input sequence of shape (B, T, F, C). | |
inputs_mask: torch.FloatTensor | |
Only used for SCA. By default, None. | |
local_attention_mask: | |
Used for local self-attention. By default, None. | |
temporal_attention_mask: | |
Used for temporal self-attention. By default, None. | |
local_head_mask: | |
By default, None. | |
temporal_head_mask: | |
By default, None. | |
pos_emb_temporal: | |
Optioanl. Used for temporal self-attention. By default, None. (max_t, num_latents, d_latents) | |
output_attentions: bool | |
Whether to return attentions weights. | |
output_hidden_states: bool | |
Whether to return all hidden states. If False, only last hidden | |
state is returned. | |
output_router_logits: bool | |
Whether to return router logits for MoE. If False, only last hidden | |
state is returned. | |
return_dict: bool | |
Whether to return a MoEModelOutputWithCrossAttentions instead of a tuple. | |
""" | |
all_hidden_states = () if output_hidden_states else None | |
all_self_attentions = () if output_attentions else None | |
all_cross_attentions = () if output_attentions else None | |
all_router_logits = () if output_router_logits else None | |
# Collect dimension info | |
batch_size, t, num_latents, d_latents = hidden_states.size() # (B, T, K, D) | |
# if self.sca_attention_to_channel: | |
# _, _, _, f = inputs.size() # (B, T, C, F) | |
# assert d_latents == f, "d_latents must be equal to kv_dim, which is input frequency dim." | |
# else: | |
# _, _, _, c = inputs.size() # (B, T, F, C) | |
# assert d_latents == c, "d_latents must be equal to kv_dim, which is input channels." | |
# Reshape (B, T, _, _) to (B*T, _, _) for SCA and local transformer. | |
hidden_states = rearrange(hidden_states, "b t k d -> (b t) k d") | |
inputs = rearrange(inputs, "b t f c -> (b t) f c") | |
# Apply the SCA between the latents (hidden_states) and inputs: | |
layer_outputs = self.spectral_cross_attention( | |
hidden_states, | |
attention_mask=None, # Input_mask is used instead for cross-attention | |
inputs=inputs, | |
inputs_mask=inputs_mask, | |
output_attentions=output_attentions, | |
) | |
hidden_states = layer_outputs[0] # (B*T, K, D) | |
if output_attentions: | |
all_cross_attentions = all_cross_attentions + (layer_outputs[1],) | |
# Apply the block of local latent transformer layers. | |
for i, layer_module in enumerate(self.local_transformer): | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
layer_head_mask = local_head_mask[i] if local_head_mask is not None else None | |
layer_outputs = layer_module( | |
hidden_states, | |
attention_mask=local_attention_mask, | |
head_mask=layer_head_mask, | |
output_attentions=output_attentions, | |
) | |
hidden_states = layer_outputs[0] # (B*T, K, D) | |
if output_attentions: | |
all_self_attentions = all_self_attentions + (layer_outputs[1],) | |
if output_router_logits: | |
all_router_logits = all_router_logits + (layer_outputs[2],) | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
# Reshape (B*T, K, D) to (B*K, T, D) for the temporal transformer. | |
hidden_states = rearrange(hidden_states, "(b t) k d -> (b k) t d", b=batch_size) | |
# Apply the block of temporal transformer layers. | |
for i, layer_module in enumerate(self.temporal_transformer): | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
layer_head_mask = temporal_head_mask[i] if temporal_head_mask is not None else None | |
if i == 0 and pos_emb_temporal is not None: | |
# Add temporal positional embeddings to the hidden_states. | |
hidden_states = hidden_states + pos_emb_temporal[:t] # pos_emb_temporal: (T, D) | |
layer_outputs = layer_module( | |
hidden_states, | |
attention_mask=temporal_attention_mask, | |
head_mask=layer_head_mask, | |
output_attentions=output_attentions, | |
) | |
hidden_states = layer_outputs[0] | |
if output_attentions: | |
all_self_attentions = all_self_attentions + (layer_outputs[1],) | |
if output_router_logits: | |
all_router_logits = all_router_logits + (layer_outputs[2],) | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
last_hideen_state = hidden_states | |
# Reshape (B*K, T, D) to (B, T, K, D) for the next block. | |
last_hideen_state = rearrange(last_hideen_state, "(b k) t d -> b t k d", b=batch_size) | |
# Prepare the outputs. | |
if not return_dict: | |
return tuple( | |
v for v in | |
[last_hideen_state, all_hidden_states, all_self_attentions, all_cross_attentions, all_router_logits] | |
if v is not None) | |
return MoEModelOutputWithCrossAttentions( | |
last_hidden_state=last_hideen_state, | |
hidden_states=all_hidden_states, | |
attentions=all_self_attentions, | |
cross_attentions=all_cross_attentions, | |
router_logits=all_router_logits, | |
) | |
class PerceiverTFEncoder(PerceiverTFPreTrainedModel): | |
"""PerceiverTFEncoder is an encoder model based on the Perceiver and Spectral Cross Attention (SCA). | |
position_encoding_type: str | |
The type of positional encoding to use. One of the following: | |
- 'trainable': trainable positional embeddings | |
- 'alibi': AlibiNet positional embeddings | |
- 'alibit': AlibiNet positional embeddings with trainable slopes for each head | |
- 'rope': RoPE (Rotary Positional Encoding) | |
(experimental w/ 'trainable') | |
- 'tkd': trainable PE (T,K,D) on latent (default for 'trainable') | |
- 'td': trainable PE (T,D) on latent | |
- 'tk': trainable PE (T,K) on latent | |
- 'kdt': trainable PE (K,D) on latent, and (T,) on temporal transformer | |
""" | |
def __init__(self, | |
config: PerceiverTFConfig, | |
sca_use_query_residual: Optional[bool] = None, | |
shared_emb: Optional[nn.Embedding] = None): | |
super().__init__(config) | |
self.config = config | |
if sca_use_query_residual is None: | |
self.sca_use_query_residual = config.sca_use_query_residual # True by default | |
self.position_encoding_type = config.position_encoding_type | |
self.sca_attention_to_channel = config.attention_to_channel | |
# Construct a latent array. | |
self.latent_array = PerceiverEmbeddings(config) # (num_latents, d_latents) | |
# Positional embeddings for the latent array. | |
if self.position_encoding_type == 'rope': | |
# (Modified) RoPE | |
self.rotary_emb_sca = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_sca, | |
config.rope_partial_pe, config.rope_trainable) | |
self.rotary_emb_latent = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_latent, | |
config.rope_partial_pe, config.rope_trainable) | |
self.rotary_emb_temporal = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_temporal, | |
config.rope_partial_pe, config.rope_trainable) | |
else: | |
self.rotary_emb_sca = None | |
self.rotary_emb_latent = None | |
self.rotary_emb_temporal = None | |
if self.position_encoding_type in ['alibi', 'alibit', 'rope', None]: | |
# alibi is imeplemented within PerceiverAlibiSelfAttention, and activated by config. | |
# RoPE is implemented without using self.pos_emb. | |
self.pos_emb = None | |
else: | |
k, d = self.latent_array.latents.size() | |
max_t = int(config.num_max_positions) + 10 # 10 is headroom for future task tokens... | |
self.pos_emb = PerceiverTFTrainablePE(self.position_encoding_type, max_t, k, d) | |
""" | |
self.pos_emb() returns: | |
pos_emb: (max_t, K, D) | |
pos_emb_temporal: (max_t, K, D) | |
""" | |
# Construct the encoder blocks. | |
blocks = [] | |
for _ in range(config.num_blocks): | |
block = PerceiverTFEncoderBlock( | |
config, | |
kv_dim=config.kv_dim, | |
sca_use_query_residual=sca_use_query_residual, | |
rotary_emb_sca=self.rotary_emb_sca, # (Modified) RoPE | |
rotary_emb_latent=self.rotary_emb_latent, | |
rotary_emb_temporal=self.rotary_emb_temporal) | |
blocks.append(block) | |
self.blocks = nn.ModuleList(blocks) | |
# Initialize weights and apply final processing | |
self.post_init() | |
def get_input_embeddings(self): | |
return self.latent_array.latents | |
def set_input_embeddings(self, value): | |
self.latent_array.latents = value | |
"""temporary fix for torch.compile issue""" | |
def forward(self, **kwargs): | |
if self.training is True: | |
return self._forward_compile(**kwargs) | |
else: | |
return self._forward_no_compile(**kwargs) | |
def _forward_no_compile(self, **kwargs): | |
return self._forward(**kwargs) | |
def _forward_compile(self, **kwargs): | |
return self._forward(**kwargs) | |
def _forward( | |
self, | |
inputs: Optional[torch.FloatTensor] = None, # (B, T, F, kv_dim) | |
inputs_embeds: Optional[torch.FloatTensor] = None, # (B, T, F, kv_dim) | |
inputs_mask: Optional[torch.FloatTensor] = None, # (B, F) Mask freq. of inputs in SCA. | |
local_attention_mask: Optional[torch.FloatTensor] = None, # (B, K) | |
temporal_attention_mask: Optional[torch.FloatTensor] = None, # (B, T) | |
local_head_mask: Optional[torch.FloatTensor] = None, | |
temporal_head_mask: Optional[torch.FloatTensor] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
output_router_logits: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
) -> Union[Tuple, MoEModelOutputWithCrossAttentions]: | |
# Inputs and inputs_embeds are tied, and actually the same. (following T5 convention) | |
# Inputs are from convoulutional features from audio. | |
# Don't be confused with latent embeddings, which is `self.latent_array.latents`, and | |
# used as hidden_state of block. | |
if inputs is None and inputs_embeds is not None: | |
inputs = inputs_embeds | |
elif inputs is None and inputs_embeds is None: | |
raise ValueError("You must provide 'inputs' or 'inputs_embeds' argument.") | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = (output_hidden_states | |
if output_hidden_states is not None else self.config.output_hidden_states) | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
batch_size, t, _f, _c = inputs.size() | |
device = inputs.device | |
# SCA attention to channels of inputs, instead of frequency bins. | |
if self.sca_attention_to_channel is True: | |
inputs = rearrange(inputs, "b t f c -> b t c f") | |
# Prepare head mask if needed | |
# 1.0 in head_mask indicate we keep the head | |
# attention_probs has shape bsz x n_heads x N x N | |
# input head_mask has shape [num_heads] or [num_blocks x num_heads] | |
# and head_mask is converted to shape [num_blocks x batch x num_heads x N x N] | |
local_head_mask = self.get_head_mask(local_head_mask, | |
self.config.num_blocks * self.config.num_local_transformers_per_block) | |
temporal_head_mask = self.get_head_mask( | |
temporal_head_mask, self.config.num_blocks * self.config.num_temporal_transformers_per_block) | |
# Prepare attention mask: not implemented | |
# Expand the latent embeddings by t: (B, K, D) --> (B, T, K, D) | |
latent_embeddings = self.latent_array(batch_size=batch_size) # (B, num_latents, d_latents) | |
expanded_latent_embeddings = latent_embeddings.unsqueeze(1).expand(-1, t, -1, -1) | |
# Add positional embeddings to the expanded latent embeddings: (B, T, K, D) | |
if self.pos_emb is not None: | |
pos_emb_latent, pos_emb_temporal = self.pos_emb.forward() | |
expanded_latent_embeddings = expanded_latent_embeddings + pos_emb_latent[:t] | |
# (max_t, K, D) -> (T, K, D) -> (B, T, K, D) auto-broadcasting | |
else: | |
pos_emb_temporal = None | |
# Lists to store intermediate outputs if required | |
all_hidden_states = [] | |
all_attentions = [] | |
all_cross_attentions = [] | |
all_router_logits = [] | |
hidden_states = expanded_latent_embeddings | |
# Forward-pass | |
for i, block in enumerate(self.blocks): | |
block_output = block(hidden_states=hidden_states, | |
inputs=inputs, | |
inputs_mask=inputs_mask, | |
local_attention_mask=local_attention_mask, | |
temporal_attention_mask=temporal_attention_mask, | |
local_head_mask=local_head_mask, | |
temporal_head_mask=temporal_head_mask, | |
pos_emb_temporal=pos_emb_temporal if i == 0 else None, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
output_router_logits=output_router_logits, | |
return_dict=True) | |
# Update the hidden_states for the next block | |
hidden_states = block_output.last_hidden_state | |
# Append to lists if required | |
if output_hidden_states: | |
all_hidden_states.append(hidden_states) | |
if output_attentions: | |
all_attentions.append(block_output.attentions) | |
all_cross_attentions.append(block_output.cross_attentions) | |
if output_router_logits: | |
all_router_logits.append(block_output.router_logits) | |
last_hidden_states = hidden_states | |
# Prepare outputs | |
if not return_dict: | |
# Convert lists to tuples | |
return (last_hidden_states, tuple(all_hidden_states) if all_hidden_states else None, | |
tuple(all_attentions) if all_attentions else None, | |
tuple(all_cross_attentions) if all_cross_attentions else None, | |
tuple(all_router_logits) if all_router_logits else None) | |
return MoEModelOutputWithCrossAttentions( | |
last_hidden_state=last_hidden_states, | |
hidden_states=tuple(all_hidden_states) if all_hidden_states else None, | |
attentions=tuple(all_attentions) if all_attentions else None, | |
cross_attentions=tuple(all_cross_attentions) if all_cross_attentions else None, | |
router_logits=tuple(all_router_logits) if all_router_logits else None) | |
def test(): | |
# In HuggingFace's Perceiver implementation: | |
# `q_dim` is the latent array dimension d_latents of ((B), num_latents, d_latents). | |
# `kv_dim`os the actual input dimension D of (B, T, D) | |
# `qk_channels`, `v_channels`: are projection dimensions for attention, (B, T, C) | |
# (B, T, D) --> projection --> (B, T, C) | |
# However, PerceiverTF does not require projection: | |
# It takes as input a latent tensor (B, num_latents, d_latents) and a conv_feat tensor (T, B, F, C) | |
# The `spectral-cross-attention` and `local-self-attention-transformer` takes as input (B*T, F, C), | |
# and C=D=d_latents. | |
from model.ops import count_parameters | |
# Test input | |
b = 2 # batch | |
t = 10 # time steps (330 for 6s in paper) | |
f = 128 # freq of conv_feat | |
c = 128 # channels of conv_feat | |
k = 24 # num_latents | |
d = 128 # d_latents | |
conv_feat = torch.randn(b, t, f, c) | |
# construct PerceiverTFEncoder | |
config = PerceiverTFConfig() | |
pe_types = ['alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None] | |
config.ff_layer_type = 'moe' | |
config.moe_num_experts = 4 | |
config.moe_topk = 2 | |
for pe_type in pe_types: | |
config.position_encoding_type = pe_type # 'alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None | |
config.num_latents = k | |
config.d_latents = d | |
config.kv_dim = c | |
config.qk_channels = d | |
config.v_channels = d | |
encoder = PerceiverTFEncoder(config) | |
encoder.eval() | |
assert encoder.latent_array.latents.size() == (k, d) | |
# forward | |
enc_hidden_state = encoder.forward(inputs_embeds=conv_feat).last_hidden_state | |
# print(enc_hidden_state.shape) # [2, 10, 24, 128] = [B, T, K, D] | |
n_param = count_parameters(encoder)[1] // 1000 | |
print(config.position_encoding_type, f'num_param: {n_param}K') | |
""" | |
PE type | num. param. | |
None | 1397K | |
alibi | 1397K | |
alibit (train slope) | 1397K | |
tkd | 2442K | |
td | 1441K | |
tk | 1405K | |
kdt | 1444K | |
MLP | 2637K | |
MoE (4 experts) | 4411K | |
MoE (6 experts) | 5594K | |
""" | |