""" PyTorch Wav2Vec2-Ebranchformer model."""
from typing import Optional
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Config,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerConfig,
Wav2Vec2ConformerEncoder,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerFeedForward as Wav2Vec2EBranchformerFeedForward,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerModel,
)
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerSelfAttention as Wav2Vec2EBranchformerSelfAttention,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Wav2Vec2EBranchformerConfig(Wav2Vec2ConformerConfig, Wav2Vec2Config):
    """Configuration for the EBranchformer model, extending the Conformer config."""
model_type = "wav2vec2-ebranchformer"
def __init__(
self,
ebranchformer_conv_dropout=0.1,
csgu_activation="identity",
csgu_kernel_size=31,
csgu_use_linear_after_conv=False,
merge_conv_kernel=31,
use_macaron_ff=True,
**kwargs,
):
super().__init__(**kwargs)
# EBranchformer related params
self.csgu_kernel_size = csgu_kernel_size
self.csgu_activation = csgu_activation
self.csgu_conv_dropout = ebranchformer_conv_dropout
self.csgu_use_linear_after_conv = csgu_use_linear_after_conv
self.merge_conv_kernel = merge_conv_kernel
self.use_macaron_ff = use_macaron_ff
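
# Where the E-Branchformer-specific fields above are consumed (read off the classes
# defined below in this file):
#   csgu_kernel_size / csgu_activation / csgu_conv_dropout / csgu_use_linear_after_conv
#       -> ConvolutionalSpatialGatingUnit (the gating half of the cgMLP branch)
#   merge_conv_kernel -> depthwise_conv_fusion in Wav2Vec2EBranchformerEncoderLayer
#   use_macaron_ff    -> enables the two macaron-style half-step feed-forwards (ff1/ff2)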


class ConvolutionalSpatialGatingUnit(torch.nn.Module):
"""Convolutional Spatial Gating Unit (CSGU)."""
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__()
n_channels = config.intermediate_size // 2 # split input channels
self.norm = torch.nn.LayerNorm(n_channels)
self.conv = torch.nn.Conv1d(
n_channels,
n_channels,
config.csgu_kernel_size,
1,
(config.csgu_kernel_size - 1) // 2,
groups=n_channels,
)
if config.csgu_use_linear_after_conv:
self.linear = torch.nn.Linear(n_channels, n_channels)
else:
self.linear = None
if config.csgu_activation == "identity":
self.act = torch.nn.Identity()
else:
self.act = ACT2FN[config.csgu_activation]
self.dropout = torch.nn.Dropout(config.csgu_conv_dropout)
def forward(self, hidden_states: torch.FloatTensor):
"""Forward method
Args:
hidden_states (torch.Tensor): (N, T, D)
Returns:
out (torch.Tensor): (N, T, D/2)
"""
x_r, x_g = hidden_states.chunk(2, dim=-1)
x_g = self.norm(x_g) # (N, T, D/2)
x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) # (N, T, D/2)
if self.linear is not None:
x_g = self.linear(x_g)
x_g = self.act(x_g)
hidden_states = x_r * x_g # (N, T, D/2)
hidden_states = self.dropout(hidden_states)
return hidden_states


class ConvolutionalGatingMLP(torch.nn.Module):
"""Convolutional Gating MLP (cgMLP)."""
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__()
self.channel_proj1 = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.intermediate_size), torch.nn.GELU()
)
self.csgu = ConvolutionalSpatialGatingUnit(config)
self.channel_proj2 = torch.nn.Linear(config.intermediate_size // 2, config.hidden_size)
def forward(self, hidden_states: torch.FloatTensor):
hidden_states = self.channel_proj1(hidden_states) # hidden_size -> intermediate_size
hidden_states = self.csgu(hidden_states) # intermediate_size -> intermediate_size/2
hidden_states = self.channel_proj2(hidden_states) # intermediate_size/2 -> hidden_size
return hidden_states
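
# Shape walk-through of the cgMLP branch above (H = hidden_size, I = intermediate_size):
#   (N, T, H) --channel_proj1--> (N, T, I) --CSGU--> (N, T, I/2) --channel_proj2--> (N, T, H)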


class Wav2Vec2EBranchformerEncoderLayer(nn.Module):
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__()
embed_dim = config.hidden_size
dropout = config.attention_dropout
        # Feed-forward 1 (macaron-style half-step FFN); kept as None when disabled so
        # the forward pass can test for it safely
        if config.use_macaron_ff:
            self.ff1 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
        else:
            self.ff1 = None
# Self-Attention
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
self.self_attn_dropout = torch.nn.Dropout(dropout)
self.self_attn = Wav2Vec2EBranchformerSelfAttention(config)
# cgMLP
self.cgMLP = ConvolutionalGatingMLP(config)
self.cgMLP_layer_norm = nn.LayerNorm(config.hidden_size)
self.cgMLP_dropout = torch.nn.Dropout(dropout)
# Merge
self.final_dropout = torch.nn.Dropout(dropout)
self.merge_proj = torch.nn.Linear(embed_dim + embed_dim, embed_dim)
self.depthwise_conv_fusion = torch.nn.Conv1d(
embed_dim + embed_dim,
embed_dim + embed_dim,
kernel_size=config.merge_conv_kernel,
stride=1,
padding=(config.merge_conv_kernel - 1) // 2,
groups=embed_dim + embed_dim,
bias=True,
)
self.final_layer_norm = nn.LayerNorm(embed_dim)
        # Feed-forward 2 (macaron-style half-step FFN); kept as None when disabled
        if config.use_macaron_ff:
            self.ff2 = nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
        else:
            self.ff2 = None
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
relative_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
        # 1. Optional macaron feed-forward 1 (half-step residual)
        if self.ff1 is not None:
residual = hidden_states
hidden_states = residual + 0.5 * self.ff1(hidden_states)
        # 2. Split the input into the global (self-attention) and local (cgMLP) branches
residual = hidden_states
global_branch = hidden_states
local_branch = hidden_states
# 3. Self-Attention branch
global_branch = self.self_attn_layer_norm(global_branch)
        global_branch, attn_weights = self.self_attn(
hidden_states=global_branch,
attention_mask=attention_mask,
relative_position_embeddings=relative_position_embeddings,
output_attentions=output_attentions,
)
global_branch = self.self_attn_dropout(global_branch)
# 4. cgMLP Branch
local_branch = self.cgMLP_layer_norm(local_branch)
local_branch = self.cgMLP(local_branch)
# 5. Merge operator
# a, concat
hidden_states = torch.cat([global_branch, local_branch], dim=-1)
merge_residual = hidden_states
# b, depth-wise conv mixing
hidden_states = merge_residual + self.depthwise_conv_fusion(hidden_states.transpose(1, 2)).transpose(1, 2)
# c, project back to original size and final dropout
hidden_states = self.final_dropout(self.merge_proj(hidden_states))
# 6. Add residual
hidden_states = residual + hidden_states
        # 7. Optional macaron feed-forward 2 (half-step residual)
        if self.ff2 is not None:
residual = hidden_states
hidden_states = residual + 0.5 * self.ff2(hidden_states)
# 8. Final layer norm
hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states, attn_weights


class Wav2Vec2EBranchformerEncoder(Wav2Vec2ConformerEncoder):
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[Wav2Vec2EBranchformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
        # Disable the convolutional positional embedding; it is not used by the E-Branchformer encoder
        self.pos_conv_embed = None


class Wav2Vec2EBranchformerModel(Wav2Vec2ConformerModel):
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__(config)
self.encoder = Wav2Vec2EBranchformerEncoder(config)
# Initialize weights and apply final processing
self.post_init()


class Wav2Vec2EBranchformerForPreTraining(Wav2Vec2ForPreTraining):
config_class = Wav2Vec2EBranchformerConfig
base_model_prefix = "wav2vec2"
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__(config)
self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
self.post_init()


class Wav2Vec2EBranchformerForCTC(Wav2Vec2ForCTC):
config_class = Wav2Vec2EBranchformerConfig
base_model_prefix = "wav2vec2"
def __init__(self, config: Wav2Vec2EBranchformerConfig):
super().__init__(config)
self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
self.post_init()
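

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): build a deliberately tiny, randomly
    # initialised model and run a forward pass on a dummy waveform. The sizes below
    # are assumptions chosen for speed, not the values of any released checkpoint.
    config = Wav2Vec2EBranchformerConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        vocab_size=32,
    )
    model = Wav2Vec2EBranchformerForCTC(config)
    model.eval()
    dummy_audio = torch.randn(1, 16000)  # ~1 second of 16 kHz mono audio
    with torch.no_grad():
        logits = model(dummy_audio).logits
    print(logits.shape)  # (batch, frames, vocab_size)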