# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. #!/usr/bin/env python3 # activation_checkpointing.py """helper function for activation checkpointing""" from typing import Union, Dict, Callable from functools import partial from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, offload_wrapper, CheckpointImpl, ) # utils.py """cascade basic blocks""" import math import backoff import random import numpy as np from typing import Optional, Tuple, Union import torch from torch import nn from torch import Tensor import torch.nn.functional as F # conformer_encoder.py """ConformerEncoder Module""" from typing import Optional, Tuple, List, Literal import abc import math import numpy as np import torch from torch import nn, Tensor from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel # activation_checkpointing.py def validate_checkpointing_config(activation_checkpointing): """validate activation checkpointing configuration""" if isinstance(activation_checkpointing, str): assert activation_checkpointing in ( "", "checkpoint", "offload", ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." elif isinstance(activation_checkpointing, dict): assert activation_checkpointing.get("module", "transformer") in ( "transformer", "attention", ), "module in activation_checkpointing has to be in ('transformer', 'attention')." else: raise ValueError("activation_checkpointing has to be a str or dict.") def embedding_checkpoint_wrapper( activation_checkpointing: Union[str, Dict], ) -> Callable: """return encoder embedding activation checkpoint wrapper""" validate_checkpointing_config(activation_checkpointing) if isinstance(activation_checkpointing, str): if activation_checkpointing: if activation_checkpointing == "offload": return offload_wrapper return partial(checkpoint_wrapper) return lambda x: x if isinstance(activation_checkpointing, dict): enabled = activation_checkpointing.get("embed", False) if enabled: offloading = activation_checkpointing.get("offload", False) if offloading: return offload_wrapper impl = ( CheckpointImpl.REENTRANT if activation_checkpointing.get("reentrant", False) else CheckpointImpl.NO_REENTRANT ) return partial(checkpoint_wrapper, checkpoint_impl=impl) return lambda x: x raise ValueError("Invalid activation_checkpointing config") def encoder_checkpoint_wrapper( activation_checkpointing: Union[str, Dict], layer_cls: type, idx: int = 0, ) -> Callable: """return encoder activation checkpoint wrapper""" validate_checkpointing_config(activation_checkpointing) if isinstance(activation_checkpointing, str): if activation_checkpointing: if activation_checkpointing == "offload": return offload_wrapper return partial(checkpoint_wrapper) return lambda x: x if isinstance(activation_checkpointing, dict): target_layer_cls = activation_checkpointing.get("module", "transformer") if target_layer_cls.lower() == "transformer": target_layer_cls = ( "EncoderLayer", "ConformerEncoderLayer", ) elif target_layer_cls.lower() == "attention": target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") checkpointing_interval = activation_checkpointing.get("interval", 1) offloading = activation_checkpointing.get("offload", False) impl = ( CheckpointImpl.REENTRANT if activation_checkpointing.get("reentrant", True) else CheckpointImpl.NO_REENTRANT ) if idx % checkpointing_interval == 0 and layer_cls.__name__ in target_layer_cls: if offloading: return offload_wrapper return partial(checkpoint_wrapper, checkpoint_impl=impl) return lambda x: x raise ValueError("Invalid activation_checkpointing config") def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]: """return activation checkpointing config for attention layer""" if isinstance(activation_checkpointing, str): return "" if isinstance(activation_checkpointing, dict): target_layer_cls = activation_checkpointing.get("module", "transformer") checkpointing_interval = activation_checkpointing.get("interval", 1) if target_layer_cls == "attention" and i % checkpointing_interval == 0: return activation_checkpointing return "" raise ValueError("Invalid activation_checkpointing config") # utils.py class Block(nn.Module): """Block abstract module""" def __init__(self, input_size, output_size): super().__init__() self.input_size = input_size self.output_size = output_size def get_activation(name="relu"): """Select an activation function by name Args: name: str activation function name, one of ["relu", "gelu", "swish", "sigmoid"], default "relu". """ name = name.lower() if name == "relu": return nn.ReLU(inplace=True) if name == "gelu": return nn.GELU() if name == "swish": return Swish() if name == "sigmoid": return torch.nn.Sigmoid() return nn.Identity() def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): """ The function is very important for Transformer Transducer Streaming mode Args: xs_len (int): sequence length chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] left_window (int): how many left chunks can be seen right_window (int): how many right chunks can be seen. It is used for chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model Torch 1.0.1 tensor([[1., 1., 0., 0.], [0., 1., 1., 0.], [0., 0., 1., 1.]]) Torch 1.4.1 tensor([[True., True., False., False.], [False., True., True., False.], [False., False., True., True.]]) """ chunk_start_idx = torch.Tensor( chunk_start_idx ).long() # first idx of each chunk, such as [0,18,36,48]. start_pad = torch.nn.functional.pad( chunk_start_idx, (1, 0) ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] end_pad = torch.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] boundary = end_pad[idx] # boundary size: [x_len] seq_range_expand = ( torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) ) # seq_range_expand size [x_len, x_len] idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) idx_right = idx + right_window idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) boundary_right = end_pad[idx_right] mask_right = seq_range_expand < boundary_right.unsqueeze(-1) return mask_left & mask_right class Swish(nn.Module): """Implement Swish activation module. From https://arxiv.org/pdf/2005.03191.pdf """ def __init__(self) -> None: super().__init__() self.act_fn = nn.Sigmoid() def forward(self, x: Tensor) -> Tensor: """Apply Swish function Args: x: torch.Tensor Input. """ return x * self.act_fn(x) class GLU(nn.Module): """Implement Gated Linear Unit (GLU) module""" def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: super().__init__() self.dim = dim self.act_name = act_name.lower() if self.act_name == "relu": self.act_fn = nn.ReLU(inplace=True) elif self.act_name == "gelu": self.act_fn = nn.GELU() elif self.act_name == "swish": self.act_fn = Swish() elif self.act_name == "sigmoid": self.act_fn = nn.Sigmoid() else: self.act_fn = nn.Identity() def forward(self, x: Tensor) -> Tensor: """GLU forward Apply Swish function on the first half of input matrices with sigmoid of the second half. Args: x: torch.Tensor Input. """ half_x, gate = x.chunk(2, dim=self.dim) return half_x * self.act_fn(gate) # TODO: Abdel, this can be improved using GLU module class GLUPointWiseConv(nn.Module): """GLUPointWiseConv module used for conformer architecture, for more details see: https://arxiv.org/pdf/2005.08100v1.pdf Args: input_dim: int input channel size. output_dim: int output channel size. kernel_size: int kernel size glu_type: str, optional activation function one of ["sigmoid", "relu", "gelu"] default "sigmoid". bias_in_glu: bool, optional use addtive bias in glu causal: bool, optional if set to True, padding is set to the half of kernel size, ie, convolution can't see future frames. default False. """ def __init__( self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False ): super().__init__() self.glu_type = glu_type self.output_dim = output_dim self.bias_in_glu = bias_in_glu if causal: self.ext_pw_conv_1d = nn.Conv1d( input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) ) else: self.ext_pw_conv_1d = nn.Conv1d( input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) // 2 ) if glu_type == "sigmoid": self.glu_act = nn.Sigmoid() elif glu_type == "relu": self.glu_act = nn.ReLU() elif glu_type == "gelu": self.glu_act = nn.GELU() elif glu_type == "swish": self.glu_act = Swish() else: raise ValueError(f"Unsupported activation type {self.glu_act}") if bias_in_glu: self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) def forward(self, x): """ Args: x: torch.Tensor input tensor """ # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case x = x.permute([0, 2, 1]) x = self.ext_pw_conv_1d(x) if self.glu_type == "bilinear": if self.bias_in_glu: x = (x[:, 0 : self.output_dim, :] + self.b1) * ( x[:, self.output_dim : self.output_dim * 2, :] + self.b2 ) else: x = (x[:, 0 : self.output_dim, :]) * ( x[:, self.output_dim : self.output_dim * 2, :] ) else: if self.bias_in_glu: x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( x[:, self.output_dim : self.output_dim * 2, :] + self.b2 ) else: x = (x[:, 0 : self.output_dim, :]) * self.glu_act( x[:, self.output_dim : self.output_dim * 2, :] ) x = x.permute([0, 2, 1]) return x class DepthWiseSeperableConv1d(nn.Module): """DepthWiseSeperableConv1d module used in Convnet module for the conformer, for more details see: https://arxiv.org/pdf/2005.08100v1.pdf Args: input_dim: int input channel size. depthwise_seperable_out_channel: int if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. kernel_size: int kernel_size depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. padding: int, optional padding for the conv1d, default: 0. """ def __init__( self, input_dim, depthwise_seperable_out_channel, kernel_size, depthwise_multiplier, padding=0, ): super().__init__() self.dw_conv = nn.Conv1d( input_dim, input_dim * depthwise_multiplier, kernel_size, 1, padding=padding, groups=input_dim, ) if depthwise_seperable_out_channel != 0: self.pw_conv = nn.Conv1d( input_dim * depthwise_multiplier, depthwise_seperable_out_channel, 1, 1, 0 ) else: self.pw_conv = nn.Identity() self.depthwise_seperable_out_channel = depthwise_seperable_out_channel def forward(self, x): """ Args: x: torch.Tensor input tensor """ x = self.dw_conv(x) if self.depthwise_seperable_out_channel != 0: x = self.pw_conv(x) return x class ConvModule(nn.Module): """ConvModule Module for the conformer block. for more details see: https://arxiv.org/pdf/2005.08100v1.pdf Args: input_dim: int input channel size. ext_pw_out_channel: int if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. ext_pw_kernel_size: int kernel size of the conv pointwise of the conformer. kernel_size: int kernel size. depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. dropout_rate: float dropout rate. causal: bool, optional if set to True, convolution have no access to future frames. default False. batch_norm: bool, optional if set to True, apply batchnorm before activation. default False chunk_se: int, optional 0 for offline SE. 1 for streaming SE, where mean is computed by accumulated history until current chunk_se. 2 for streaming SE, where mean is computed by only the current chunk. chunk_size: int, optional chunk size for cnn. default 18 activation: str, optional activation function used in ConvModule, default: "relu". glu_type: str, optional activation function used for the glu, default: "sigmoid". bias_in_glu: bool, optional if set to True, use additive bias in the weight module before GLU. linear_glu_in_convm: bool, optional if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. export: bool, optional, if set to True, padding is equal to 0. This is for inference, or onnx export. Typically this is set by the export program or the decoder program, and it isn't present in your config file. default False """ def __init__( self, input_dim, ext_pw_out_channel, depthwise_seperable_out_channel, ext_pw_kernel_size, kernel_size, depthwise_multiplier, dropout_rate, causal=False, batch_norm=False, chunk_se=0, chunk_size=18, activation="relu", glu_type="sigmoid", bias_in_glu=True, linear_glu_in_convm=False, export=False, ): super().__init__() self.layer_norm = nn.LayerNorm(input_dim) self.input_dim = input_dim self.ext_pw_out_channel = ext_pw_out_channel self.ext_pw_kernel_size = ext_pw_kernel_size self.depthwise_seperable_out_channel = depthwise_seperable_out_channel self.glu_type = glu_type self.bias_in_glu = bias_in_glu self.linear_glu_in_convm = linear_glu_in_convm self.causal = causal self._add_ext_pw_layer() self.batch_norm = batch_norm self.kernel_size = kernel_size if batch_norm: self.bn_layer = nn.BatchNorm1d(input_dim) self.act = get_activation(activation) self.dropout = nn.Dropout(dropout_rate) self.export = export if causal: if export: # Inference only. padding = 0 # A cache is concatenated to the left. No padding in the kernel. else: # Training only. Padding will be added symmetrically on both sides. # After convolution, clip off kernel_size-1 points on the right. padding = kernel_size - 1 else: padding = (kernel_size - 1) // 2 self.dw_sep_conv_1d = DepthWiseSeperableConv1d( input_dim, depthwise_seperable_out_channel, kernel_size, depthwise_multiplier, padding=padding, ) if depthwise_seperable_out_channel != 0: if input_dim != depthwise_seperable_out_channel: self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) else: if depthwise_multiplier != 1: self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) def _add_ext_pw_layer(self): """ This function is an extension of __init__ function and dedicated to the convolution module creation of the conformer. """ self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity() # jit hacks. self.squeeze_excitation = nn.Identity() # jit. self.apply_ln1 = self.fix_len1 = False # jit. if self.ext_pw_out_channel != 0: if self.causal: self.ext_pw_conv_1d = nn.Conv1d( self.input_dim, self.ext_pw_out_channel, self.ext_pw_kernel_size, 1, padding=(self.ext_pw_kernel_size - 1), ) if self.ext_pw_kernel_size > 1: self.fix_len1 = True else: self.fix_len1 = False else: self.ext_pw_conv_1d = nn.Conv1d( self.input_dim, self.ext_pw_out_channel, self.ext_pw_kernel_size, 1, padding=(self.ext_pw_kernel_size - 1) // 2, ) self.fix_len1 = False if self.linear_glu_in_convm: self.glu = GLULinear( self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu ) else: self.glu = GLUPointWiseConv( self.input_dim, self.ext_pw_out_channel, self.ext_pw_kernel_size, self.glu_type, self.bias_in_glu, self.causal, ) if self.input_dim != self.ext_pw_out_channel: self.apply_ln1 = True self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) else: self.apply_ln1 = False else: self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) def forward(self, x): """ConvModule Forward. Args: x: torch.Tensor input tensor. """ x = self.layer_norm(x) if self.ext_pw_out_channel != 0: x = self.glu(x) if self.causal and self.ext_pw_kernel_size > 1: x = x[:, : -(self.ext_pw_kernel_size - 1), :] if self.apply_ln1: x = self.ln1(x) else: x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] x = x_0 + x_1 x = x.permute([0, 2, 1]) x = self.dw_sep_conv_1d(x) if self.causal and self.kernel_size > 1: x = x[:, :, : -(self.kernel_size - 1)] if hasattr(self, "ln2"): x = x.permute([0, 2, 1]) x = self.ln2(x) x = x.permute([0, 2, 1]) if self.batch_norm: x = self.bn_layer(x) x = self.act(x) if self.ext_pw_out_channel != 0: x = self.ext_pw_conv_1d(x) if self.fix_len1: x = x[:, :, : -(self.ext_pw_kernel_size - 1)] if self.apply_ln1: x = x.permute([0, 2, 1]) x = self.ln1(x) x = x.permute([0, 2, 1]) x = x.permute([0, 2, 1]) else: x = x.unsqueeze(1).permute([0, 1, 3, 2]) x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] x = x.squeeze(1) x = self.dropout(x) return x class GLULinear(nn.Module): """Linear + GLU module Args: input_dim: int input size output_dim: int output size. glu_type: activation function name used in glu module. default "sigmoid" (swish function). bias_in_glu: bool, optional If True, the addtive bias is added. Default False. """ def __init__( self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True, ): super().__init__() self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) self.glu_act = GLU(-1, glu_type) def forward(self, x): """GLULinear forward Args: x: torch.Tensor inpute tensor. """ x = self.linear(x) return self.glu_act(x) class FeedForward(nn.Module): """FeedForward Module. For more details see Conformer paper: https://arxiv.org/pdf/2005.08100.pdf Args: d_model: int input size. d_inner: int output size. dropout_rate: float, dropout rate. activation: str, activation function name, one of ["relu", "swish", "sigmoid"], sigmoid activation is only used with "glu_in_fnn=True", default "sigmoid". bias_in_glu: bool, optional """ def __init__( self, d_model, d_inner, dropout_rate, activation="sigmoid", bias_in_glu=True, ): super().__init__() self.d_model = d_model self.d_inner = d_inner self.layer_norm = nn.LayerNorm(d_model) module = GLULinear(d_model, d_inner, activation, bias_in_glu) self.net = nn.Sequential( module, nn.Dropout(dropout_rate), nn.Linear(d_inner, d_model), nn.Dropout(dropout_rate), ) def forward(self, x): """FeedForward forward function. Args: x: torch.Tensor input tensor. """ out = self.net(self.layer_norm(x)) return out #### positional encoding starts here def _pre_hook( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): """Perform pre-hook in load_state_dict for backward compatibility. Note: We saved self.pe until v.0.5.2 but we have omitted it later. Therefore, we remove the item "pe" from `state_dict` for backward compatibility. """ k = prefix + "pe" if k in state_dict: state_dict.pop(k) class T5RelativeAttentionLogitBias(nn.Module): """ This module implements the relative position bias described in Section 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf The Huggingface implementation is used as a reference https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. I've made these modifications to the original T5 bias: - Skipping of the bucketing step. Original T5 bias converted rel position distances into logarithmically increasing buckets. This is supposed to help with length generalization. - I just directly use rel position index as bias values, as we don't need length generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. - I've also extended it so that biases can be asymmetric, the default implementation treats L->R and R->L the same. Asymmetric was found to yield better results in my experiments. Args: num_heads: int Number of attention heads num_buckets: int Number of buckets to use for relative attention bias. This is the size of the learnable bias parameter. Bucketing is not yet supported, so this defaults to -1 which means no bucketing is used (max_distance determines size of bias param). max_distance: int Maximum distance to use for relative attention bias. With num_buckets=-1, this directly controls the max size of the bias parameter. When num_buckets > 0 is supported, this will control the maximum distance for logarithmic bucketing after which all positions are in the same bucket. symmetric: bool Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias params to distinguish L->R from R->L. This was found to be better for the encoder. """ def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets self.max_distance = max_distance self.symmetric = symmetric self._skip_bucketing = self.num_buckets < 0 if self._skip_bucketing: self.num_buckets = max_distance else: raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested") if not self.symmetric: self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) def forward(self, x): # instantiate bias compatible with shape of x maxpos = x.size(1) context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None] memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :] relative_position = memory_position - context_position # clipping to a maximum distance using ops that play well with ONNX export relative_position = relative_position.masked_fill( relative_position < -self.max_distance, -self.max_distance ) relative_position = relative_position.masked_fill( relative_position > self.max_distance - 1, self.max_distance - 1 ) # mapping from relative position to index in the bias parameter if self._skip_bucketing: bias_idx = relative_position else: bias_idx = self._bucket_relative_position(relative_position) if self.symmetric: bias_idx = bias_idx.abs() else: bias_idx += self.num_buckets // 2 t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] return t5_rel_att_bias def _bucket_relative_position(self, relative_position): # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference # this also needs to be extended to support asymmetric +/- ve positions relative_buckets = 0 if not self.causal: num_buckets //= 2 relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions max_exact = num_buckets // 2 is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance relative_position_if_large = max_exact + ( torch.log(relative_position.float() / max_exact) / math.log(self.max_distance / max_exact) * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) ) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets class AbsolutePositionalEncoding(nn.Module): """Absolute Positional encoding module. This module implement Absolute sinusoidal positional encoding from: https://arxiv.org/pdf/1706.03762.pdf Args: d_model: int Input embedding size. dropout_rate: float dropout rate max_len: int, optional Maximum input length sequence, Default 5000 """ def __init__(self, d_model, dropout_rate, max_len=5000): """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self._register_load_state_dict_pre_hook(_pre_hook) def extend_pe(self, x): """Reset the positional encodings. Args: x: torch.Tensor """ if self.pe is not None: if self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.d_model) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor): """Add positional encoding. Args: x: torch.Tensor Input tensor. shape is (batch, time, ...) Returns: torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1)] return self.dropout(x) #### forward embedding layers starts here @backoff.on_exception(backoff.expo, Exception, max_tries=10) def np_loadtxt_with_retry(filepath): """np.loadtxt with retry Args: filepath: str file path to the numpy array. """ result = np.loadtxt(filepath, dtype="f") return result class MeanVarianceNormLayer(nn.Module): """Mean/variance normalization layer. Will substract mean and multiply input by inverted standard deviation. Typically used as a very first layer in a model. Args: input_size: int layer input size. """ def __init__(self, input_size): super().__init__() self.input_size = input_size self.register_buffer("global_mean", torch.zeros(input_size)) self.register_buffer("global_invstd", torch.ones(input_size)) self.global_mean: Optional[Tensor] self.global_invstd: Optional[Tensor] def forward(self, input_: Tensor) -> Tensor: """MeanVarianceNormLayer Forward Args: input_: torch.Tensor input tensor. """ return (input_ - self.global_mean) * self.global_invstd def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): """Load feature mean and variance used for normalization. Args: mean_file: str path to the feature mean statistics file. invstd_file: str path to the features inverted standard deviation statistics file. cuside_features: bool Boolean that indicates CUSIDE is being used. The statistics of CUSIDE features are copied from the normal features """ self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file)) self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file)) if cuside_features: self.global_mean.data = torch.cat((self.global_mean.data, self.global_mean.data), 0) self.global_invstd.data = torch.cat( (self.global_invstd.data, self.global_invstd.data), 0 ) class CausalConv1D(nn.Conv1d): """ A causal version of nn.Conv1d where each step would have limited access to locations on its right or left All arguments are the same as nn.Conv1d except padding. If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. It would make it possible to control the number of steps to be accessible on the right and left. This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Union[str, int] = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", device=None, dtype=None, ) -> None: self.cache_drop_size = None if padding is None: self._left_padding = kernel_size - 1 self._right_padding = stride - 1 else: if stride != 1 and padding != kernel_size - 1: raise ValueError("No striding allowed for non-symmetric convolutions!") if isinstance(padding, int): self._left_padding = padding self._right_padding = padding elif ( isinstance(padding, list) and len(padding) == 2 and padding[0] + padding[1] == kernel_size - 1 ): self._left_padding = padding[0] self._right_padding = padding[1] else: raise ValueError(f"Invalid padding param: {padding}!") self._max_cache_len = self._left_padding super().__init__( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype, ) def update_cache(self, x, cache=None): if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache else: new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: next_cache = new_x[:, :, : -self.cache_drop_size] else: next_cache = new_x next_cache = next_cache[:, :, -cache.size(-1) :] return new_x, next_cache def forward(self, x, cache=None): x, cache = self.update_cache(x, cache=cache) x = super().forward(x) if cache is None: return x else: return x, cache class CausalConv2D(nn.Conv2d): """ A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down All arguments are the same as nn.Conv2d except padding which should be set as None """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Union[str, int] = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", device=None, dtype=None, ) -> None: if padding is not None: raise ValueError("Argument padding should be set to None for CausalConv2D.") self._left_padding = kernel_size - 1 self._right_padding = stride - 1 padding = 0 super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype, ) def forward( self, x, ): if self.training: x = F.pad( x, pad=( self._left_padding, self._right_padding, self._left_padding, self._right_padding, ), ) else: x = F.pad( x, pad=(self._left_padding, self._right_padding, 0, 0), ) x = super().forward(x) return x class NemoConvSubsampling(torch.nn.Module): """Convlutional subsampling module, taken from NeMo ASR (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions after the first layer, whereas the former does not. Args: subsampling_factor (int): Time reduction factor feat_in (int): size of the input features feat_out (int): size of the output features subsampling (str): The subsampling technique, choose from {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} conv_channels (int): Number of channels for the convolution layers, default is 256. subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 activation (Module): activation function, default is nn.ReLU() is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access to locations on its right or left """ def __init__( self, feat_in, feat_out, subsampling_factor=4, subsampling="dw_striding", conv_channels=256, subsampling_conv_chunking_factor=1, activation=nn.ReLU(), is_causal=False, ): super().__init__() self._subsampling = subsampling self._conv_channels = conv_channels self._feat_in = feat_in self._feat_out = feat_out if subsampling_factor % 2 != 0: raise ValueError("Sampling factor should be a multiply of 2!") self._sampling_num = int(math.log(subsampling_factor, 2)) self.subsampling_factor = subsampling_factor self.is_causal = is_causal self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d") if ( subsampling_conv_chunking_factor != -1 and subsampling_conv_chunking_factor != 1 and subsampling_conv_chunking_factor % 2 != 0 ): raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor in_channels = 1 layers = [] if subsampling == "dw_striding": self._stride = 2 self._kernel_size = 3 self._ceil_mode = False if self.is_causal: self._left_padding = self._kernel_size - 1 self._right_padding = self._stride - 1 self._max_cache_len = subsampling_factor + 1 else: self._left_padding = (self._kernel_size - 1) // 2 self._right_padding = (self._kernel_size - 1) // 2 self._max_cache_len = 0 # Layer 1 if self.is_causal: layers.append( CausalConv2D( in_channels=in_channels, out_channels=conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=None, ) ) else: layers.append( torch.nn.Conv2d( in_channels=in_channels, out_channels=conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, ) ) in_channels = conv_channels layers.append(activation) for i in range(self._sampling_num - 1): if self.is_causal: layers.append( CausalConv2D( in_channels=in_channels, out_channels=in_channels, kernel_size=self._kernel_size, stride=self._stride, padding=None, groups=in_channels, ) ) else: layers.append( torch.nn.Conv2d( in_channels=in_channels, out_channels=in_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, groups=in_channels, ) ) layers.append( torch.nn.Conv2d( in_channels=in_channels, out_channels=conv_channels, kernel_size=1, stride=1, padding=0, groups=1, ) ) layers.append(activation) in_channels = conv_channels elif subsampling == "striding": self._stride = 2 self._kernel_size = 3 self._ceil_mode = False if self.is_causal: self._left_padding = self._kernel_size - 1 self._right_padding = self._stride - 1 self._max_cache_len = subsampling_factor + 1 else: self._left_padding = (self._kernel_size - 1) // 2 self._right_padding = (self._kernel_size - 1) // 2 self._max_cache_len = 0 for i in range(self._sampling_num): if self.is_causal: layers.append( CausalConv2D( in_channels=in_channels, out_channels=conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=None, ) ) else: layers.append( torch.nn.Conv2d( in_channels=in_channels, out_channels=conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, ) ) layers.append(activation) in_channels = conv_channels elif subsampling == "striding_conv1d": in_channels = feat_in self._stride = 2 self._kernel_size = 5 self._ceil_mode = False if self.is_causal: self._left_padding = self._kernel_size - 1 self._right_padding = self._stride - 1 self._max_cache_len = subsampling_factor + 1 else: self._left_padding = (self._kernel_size - 1) // 2 self._right_padding = (self._kernel_size - 1) // 2 self._max_cache_len = 0 for i in range(self._sampling_num): if self.is_causal: layers.append( CausalConv1D( in_channels=in_channels, out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=None, ) ) else: layers.append( torch.nn.Conv1d( in_channels=in_channels, out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, ) ) layers.append(activation) in_channels = conv_channels elif subsampling == "dw_striding_conv1d": in_channels = feat_in self._stride = 2 self._kernel_size = 5 self._ceil_mode = False self._left_padding = (self._kernel_size - 1) // 2 self._right_padding = (self._kernel_size - 1) // 2 # Layer 1 layers.extend( [ torch.nn.Conv1d( in_channels=in_channels, out_channels=in_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, groups=in_channels, ), torch.nn.Conv1d( in_channels=in_channels, out_channels=feat_out if self._sampling_num == 1 else conv_channels, kernel_size=1, stride=1, padding=0, groups=1, ), ] ) in_channels = conv_channels layers.append(activation) for i in range(self._sampling_num - 1): layers.extend( [ torch.nn.Conv1d( in_channels=in_channels, out_channels=in_channels, kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, groups=in_channels, ), torch.nn.Conv1d( in_channels=in_channels, out_channels=feat_out if self._sampling_num == i + 2 else conv_channels, kernel_size=1, stride=1, padding=0, groups=1, ), ] ) layers.append(activation) in_channels = conv_channels else: raise ValueError(f"Not valid sub-sampling: {subsampling}!") if subsampling in ["dw_striding", "striding"]: in_length = torch.tensor(feat_in, dtype=torch.float) out_length = calc_length( lengths=in_length, all_paddings=self._left_padding + self._right_padding, kernel_size=self._kernel_size, stride=self._stride, ceil_mode=self._ceil_mode, repeat_num=self._sampling_num, ) self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) self.conv2d_subsampling = True elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: self.out = None self.conv2d_subsampling = False else: raise ValueError(f"Not valid sub-sampling: {subsampling}!") self.conv = torch.nn.Sequential(*layers) def get_sampling_frames(self): return [1, self.subsampling_factor] def get_streaming_cache_size(self): return [0, self.subsampling_factor + 1] def forward(self, x, mask): """ Forward method for NeMo subsampling. Args: x[Batch, Time, Filters]: torch.Tensor input tensor x_mask: torch.Tensor input mask Returns: x: torch.Tensor Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) pad_mask: torch.Tensor tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) """ # Unsqueeze Channel Axis if self.conv2d_subsampling: x = x.unsqueeze(1) # Transpose to Channel First mode else: x = x.transpose(1, 2) # split inputs if chunking_factor is set if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: if self.subsampling_conv_chunking_factor == 1: # if subsampling_conv_chunking_factor is 1, we split only if needed # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 x_ceil = 2**31 / self._conv_channels * self._stride * self._stride if torch.numel(x) > x_ceil: need_to_split = True else: need_to_split = False else: # if subsampling_conv_chunking_factor > 1 we always split need_to_split = True if need_to_split: x, success = self.conv_split_by_batch(x) if not success: # if unable to split by batch, try by channel if self._subsampling == "dw_striding": x = self.conv_split_by_channel(x) else: x = self.conv(x) # try anyway else: x = self.conv(x) else: x = self.conv(x) # Flatten Channel and Frequency Axes if self.conv2d_subsampling: b, c, t, f = x.size() x = self.out(x.transpose(1, 2).reshape(b, t, -1)) # Transpose to Channel Last mode else: x = x.transpose(1, 2) if mask is None: return x, None max_audio_length = x.shape[1] feature_lens = mask.sum(1) padding_length = torch.ceil(feature_lens / self.subsampling_factor) if self.is_causal and self.subsampling_causal_cond: feature_lens_remainder = feature_lens % self.subsampling_factor padding_length[feature_lens_remainder != 1] += 1 pad_mask = ( torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) ) return x, pad_mask.unsqueeze(1) def reset_parameters(self): # initialize weights if self._subsampling == "dw_striding": with torch.no_grad(): # init conv scale = 1.0 / self._kernel_size dw_max = (self._kernel_size**2) ** -0.5 pw_max = self._conv_channels**-0.5 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) for idx in range(2, len(self.conv), 3): torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) def conv_split_by_batch(self, x): """Tries to split input by batch, run conv and concat results""" b, _, _, _ = x.size() if b == 1: # can't split if batch size is 1 return x, False if self.subsampling_conv_chunking_factor > 1: cf = self.subsampling_conv_chunking_factor else: # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 x_ceil = 2**31 / self._conv_channels * self._stride * self._stride p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) cf = 2**p new_batch_size = b // cf if new_batch_size == 0: # input is too big return x, False return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True def conv_split_by_channel(self, x): """For dw convs, tries to split input by time, run conv and concat results""" x = self.conv[0](x) # full conv2D x = self.conv[1](x) # activation for i in range(self._sampling_num - 1): _, c, t, _ = x.size() if self.subsampling_conv_chunking_factor > 1: cf = self.subsampling_conv_chunking_factor else: # avoiding a bug / feature limiting indexing of tensors to 2**31 # see https://github.com/pytorch/pytorch/issues/80020 p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) cf = 2**p new_c = int(c // cf) if new_c == 0: new_c = 1 new_t = int(t // cf) if new_t == 0: new_t = 1 x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) # conv2D, depthwise # splitting pointwise convs by time x = torch.cat( [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2 ) # conv2D, pointwise x = self.conv[i * 3 + 4](x) # activation return x def channel_chunked_conv(self, conv, chunk_size, x): """Performs channel chunked convolution""" ind = 0 out_chunks = [] for chunk in torch.split(x, chunk_size, 1): step = chunk.size()[1] if self.is_causal: chunk = nn.functional.pad( chunk, pad=( self._kernel_size - 1, self._stride - 1, self._kernel_size - 1, self._stride - 1, ), ) ch_out = nn.functional.conv2d( chunk, conv.weight[ind : ind + step, :, :, :], bias=conv.bias[ind : ind + step], stride=self._stride, padding=0, groups=step, ) else: ch_out = nn.functional.conv2d( chunk, conv.weight[ind : ind + step, :, :, :], bias=conv.bias[ind : ind + step], stride=self._stride, padding=self._left_padding, groups=step, ) out_chunks.append(ch_out) ind += step return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): if ( subsampling_conv_chunking_factor != -1 and subsampling_conv_chunking_factor != 1 and subsampling_conv_chunking_factor % 2 != 0 ): raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" add_pad: float = all_paddings - kernel_size one: float = 1.0 for i in range(repeat_num): lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one if ceil_mode: lengths = torch.ceil(lengths) else: lengths = torch.floor(lengths) return lengths.to(dtype=torch.int) #### multihead attention starts here class AttModule(nn.Module): """Attention abstraction module""" def __init__(self): super().__init__() self.export_mode = False def set_export(self, mode=True): """set the export mode""" self.export_mode = mode def forward( self, x: Tensor, memory: Optional[Tensor] = None, pos_emb: Optional[Tensor] = None, att_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: """AttModule forward Args: x: torch.Tensor input tensor. memory: torch.Tensor, optional memory tensor. pos_emb: torch.Tensor, optional positional encoder embedding. att_mask: torch.Tensor, optional attention mask tensor. """ return x, memory, pos_emb, att_mask class AttBlock(Block, AttModule): """Attention Block module to support both Attention and Block module.""" def memory_dims(self, max_len=False): """memory dimensions""" return (1, self.input_size) def masked_softmax( scores, mask: Optional[Tensor], ): if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) else: attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) return attn class MultiHeadedAttention(nn.Module): """Multi-Head Attention layer with optional relative position embedding and GLU. Args: n_head: int the number of heads. n_feat: int input size features. dropout_rate: float dropout rate. use_LN: bool apply layer norm or not dropout_at_output: bool whether to apply dropout at output attention_inner_dim: int, optional the attention dimension used in the class, it can be different from the input dimension n_feat. default: -1 (equal to n_feat). use_pt_scaled_dot_product_attention: bool, optional if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT be used in ONNX decoding due to a lack of support. In that case, we use the original attention implementation, which shows no regression. default: False. n_value: int, optional if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. group_size: int, optional. must divide `n_head` if group_size > 1: GQA if group_size = 1: MHA if group_size = n_head: MQA """ inv_sqrt_d_k: torch.jit.Final[float] h: torch.jit.Final[int] h_k: torch.jit.Final[int] g: torch.jit.Final[int] def __init__( self, n_head, n_feat, dropout_rate, attention_inner_dim=-1, glu_type="swish", bias_in_glu=True, use_pt_scaled_dot_product_attention=False, n_value=-1, group_size: int = 1, ): super().__init__() if n_value == -1: n_value = n_feat if attention_inner_dim == -1: attention_inner_dim = n_feat assert attention_inner_dim % n_head == 0 # We assume d_v always equals d_k self.d_k = attention_inner_dim // n_head self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) self.h = n_head assert n_head % group_size == 0, "group_size must divide n_head" self.g = group_size self.h_k = n_head // group_size self.linear_q = nn.Linear(n_feat, attention_inner_dim) self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) self.attn = torch.jit.Attribute(None, Optional[Tensor]) self.dropout = nn.Dropout(p=dropout_rate) self.dropout_rate = dropout_rate self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention if use_pt_scaled_dot_product_attention and group_size > 1: raise ValueError("Cannot use PT Scaled Attention with GQA") # Torchscript eager quantization. Note that these functions below are # NOOPs and have very little impact on performance unless quantization is # enabled. self.quant_q = torch.ao.quantization.QuantStub() self.quant_x = torch.ao.quantization.QuantStub() self.dequant = torch.ao.quantization.DeQuantStub() self.ffunc = torch.ao.nn.quantized.FloatFunctional() def forward( self, query: Tensor, key: Tensor, value: Tensor, pos_k: Tensor, pos_v: Tensor, mask: Optional[Tensor], relative_attention_bias: Optional[Tensor] = None, ): """Compute 'Scaled Dot Product Attention'. Args: query: torch.Tensor query tensor (batch, time1, size) key: torch.Tensor key tensor (batch, time2, size) value: torch.Tensor value tensor (batch, time1, size) pos_k: torch.Tensor key tensor used for relative positional embedding. pos_v: torch.Tensor value tensor used for relative positional embedding. mask: torch.Tensor mask tensor (batch, time1, time2) relative_attention_bias: torch.Tensor bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) """ n_batch = query.size(0) q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) q = ( q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() else q.transpose(1, 2) * self.inv_sqrt_d_k ) k = k.transpose(1, 2) # (batch, head_k, time2, d_k) v = v.transpose(1, 2) # (batch, head_k, time2, d_k) if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): attn_mask = None if mask is not None: mask = mask.unsqueeze(1) if relative_attention_bias is not None: attn_mask = mask + relative_attention_bias else: attn_mask = mask if mask.dtype != q.dtype: attn_mask = attn_mask.to(q.dtype) with torch.backends.cuda.sdp_kernel( enable_flash=True, enable_math=True, enable_mem_efficient=True ): x = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout_rate, ) else: if self.h != self.h_k: q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) else: A = torch.matmul(q, k.transpose(-2, -1)) if pos_k is not None: if self.h != self.h_k: B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) else: reshape_q = ( q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1) ) # (t1,nh,dk) B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) # pos_k: (t1,dk,t2) B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) scores = A + B else: scores = A if relative_attention_bias is not None: scores = scores + relative_attention_bias attn = masked_softmax(scores, mask) # (batch, head, time1, time2) self.attn = attn p_attn = self.dropout(attn) x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) if pos_v is not None: reshape_attn = ( p_attn.contiguous() .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) .transpose(0, 1) ) # (t1, bh, t2) attn_v = ( torch.matmul(reshape_attn, pos_v) .transpose(0, 1) .contiguous() .view(n_batch, self.h, pos_v.size(0), self.d_k) ) x = x + attn_v x = ( x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) ) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) def unfold_tensor(xs_pad, max_seq_len): """ For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: xs_pad: N, T, D """ _, _, D = xs_pad.shape xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T # N x D x 1 x T => N x (D x max_seq_len) x T' xs_pad = F.unfold( xs_pad[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len), ) new_bsz, _, slen = xs_pad.shape # N x D x max_seq_len x T' xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) # N x T' x max_seq_len x D xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() # NT' x max_seq_len x D xs_pad = xs_pad.view(-1, max_seq_len, D) return xs_pad # conformer_encoder.py class MultiSequential(torch.nn.Sequential): """Multi-input multi-output torch.nn.Sequential""" @torch.jit.ignore def forward(self, *args): """Forward method implementation.""" for m in self: args = m(*args) return args def repeat(repeat_num, module_gen_fn): """repeat module N times :param int repeat_num: repeat time :param function module_gen_fn: function to generate module :return: repeated modules :rtype: MultiSequential """ return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) class ConformerEncoderLayer(nn.Module): """ConformerEncoder Layer module. for more details see conformer paper: https://arxiv.org/abs/2005.08100 This module implement the Conformer block layer. Args: d_model: int attention dim. ext_pw_out_channel: int if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value will be used to compute the hidden channels of the Conv1D. n_head: int the number of heads for multihead attention module. d_ffn: int output size of the feed_forward blocks. ext_pw_kernel_size: int kernel size of the conv pointwise of the conformer. kernel_size: int kernel size. dropout_rate: float dropout rate. causal: bool, optional if set to True, convolution have no access to future frames. default False. batch_norm: bool, optional if set to True, apply batchnorm before activation in ConvModule layer of the conformer. default False activation: str, optional activation function name, one of ["relu", "swish", "sigmoid"], sigmoid activation is only used with "glu_in_fnn=True", default "relu". chunk_se: int, optional 0 for offline SE. 1 for streaming SE, where mean is computed by accumulated history until current chunk_se. 2 for streaming SE, where mean is computed by only the current chunk. default 0. chunk_size: int, optional chunk_size for cnn. default 18 conv_activation: str, optional activation function used in ConvModule part of the conformer, default "relu". conv_glu_type: str, optional activation function used for the glu inside the ConvModule part of the conformer. default: "sigmoid". bias_in_glu: bool, optional if set to True, use additive bias in the weight module before GLU. linear_glu_in_convm: bool, optional if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. attention_innner_dim: int, otional if equal to -1, attention dim for linears k/q/v is equal to d_model. otherwise attention_innner_dim is used. default -1. attention_glu_type: str, optional activation function for glu used in the multihead attention, default "swish". activation_checkpointing: str, optional a dictionarry of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. "interval": int, default 1, interval of applying activation checkpointing, interval = 1 means that we apply checkpointing on every layer (if activation), otherwise, we apply it every x interval. "offload": bool, default False, if set to True, we offload activation to cpu and reload it during backward, otherwise, we recalculate activation in backward. default "". export: bool, optional if set to True, it remove the padding from convolutional layers and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional if set to True, use pytorch's scaled dot product attention implementation in training. attn_group_sizes: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attn_group_sizes < attention_heads = Grouped-Query Attention attn_group_sizes = attenion_heads = Multi-Query Attention """ def __init__( self, d_model=512, ext_pw_out_channel=0, depthwise_seperable_out_channel=256, depthwise_multiplier=1, n_head=4, d_ffn=2048, ext_pw_kernel_size=1, kernel_size=3, dropout_rate=0.1, causal=False, batch_norm=False, activation="relu", chunk_se=0, chunk_size=18, conv_activation="relu", conv_glu_type="sigmoid", bias_in_glu=True, linear_glu_in_convm=False, attention_innner_dim=-1, attention_glu_type="swish", activation_checkpointing="", export=False, use_pt_scaled_dot_product_attention=False, attn_group_sizes: int = 1, ): super().__init__() self.feed_forward_in = FeedForward( d_model=d_model, d_inner=d_ffn, dropout_rate=dropout_rate, activation=activation, bias_in_glu=bias_in_glu, ) self.self_attn = encoder_checkpoint_wrapper( activation_checkpointing, MultiHeadedAttention, )( MultiHeadedAttention( n_head, d_model, dropout_rate, attention_innner_dim, attention_glu_type, bias_in_glu, use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, group_size=attn_group_sizes, ) ) self.conv = ConvModule( d_model, ext_pw_out_channel, depthwise_seperable_out_channel, ext_pw_kernel_size, kernel_size, depthwise_multiplier, dropout_rate, causal, batch_norm, chunk_se, chunk_size, conv_activation, conv_glu_type, bias_in_glu, linear_glu_in_convm, export=export, ) self.feed_forward_out = FeedForward( d_model=d_model, d_inner=d_ffn, dropout_rate=dropout_rate, activation=activation, bias_in_glu=bias_in_glu, ) self.layer_norm_att = nn.LayerNorm(d_model) self.layer_norm = nn.LayerNorm(d_model) def forward( self, x, pos_k, pos_v, mask, relative_attention_bias: Optional[Tensor] = None, ): """ConformerEncoder forward. Args: x: torch.Tensor input feature of shape (batch, max_time_in, size) pos_k: torch.Tensor positional key embedding. mask: torch.Tensor mask for x (batch, max_time_in) relative_attention_bias: Optional[torch.Tensor] bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) x = x + self.self_attn( norm_x, norm_x, norm_x, pos_k, pos_v, mask, relative_attention_bias=relative_attention_bias, ) x = x + self.conv(x) x = x + 0.5 * self.feed_forward_out(x) out = self.layer_norm(x) return out, pos_k, pos_v, mask class TransformerEncoderBase(abc.ABC, nn.Module): """The Base class for Transformer based encoders Please set causal = True in streaming model Args: input_size: int input feature dimension. chunk_size: int, list(int) Number of frames for each chunk This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training Some examples for the 2 cases: chunk_size = 12 chunk_size = [6, 8, 12, 24] left_chunk: int, list(int) Number of chunks used for masking in streaming mode. This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training. When chunk_size is a list, left_chunk must be a list with same length. Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] attention_dim: int, optional attention dimension. default 256. attention_heads: int, optional the number of heads. default 4 input_layer: str, optional input layer type before Conformer, one of ["linear", "conv2d", "custom", "vgg2l", "embed"], default "conv2d" cnn_out: int, optional the number of CNN channels before Conformer. default -1. cnn_layer_norm: bool, optional layer norm between Conformer and the first CNN. default False. time_reduction: int, optional time reduction factor default 4 dropout_rate: float, optional dropout rate. default 0.1 padding_idx: int, optional padding index for input_layer=embed default -1 relative_attention_bias_args: dict, optional use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} additional method-specific arguments can be provided (see transformer_base.py) positional_dropout_rate: float, optional dropout rate after positional encoding. default 0.0 nemo_conv_settings: dict, optional A dictionary of settings for NeMo Subsampling. default None conv2d_extra_padding: str, optional Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True). if True or feat_time, the extra padding is added into non full supraframe utts in batch. Default: none attention_group_size: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attenion_heads = Multi-Query Attention """ def __init__( self, input_size, chunk_size, left_chunk, attention_dim=256, attention_heads=4, input_layer="nemo_conv", cnn_out=-1, cnn_layer_norm=False, time_reduction=4, dropout_rate=0.0, padding_idx=-1, relative_attention_bias_args=None, positional_dropout_rate=0.0, nemo_conv_settings=None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", attention_group_size=1, encoder_embedding_config=None, ): super().__init__() self.input_size = input_size self.input_layer = input_layer self.chunk_size = chunk_size self.left_chunk = left_chunk self.attention_dim = attention_dim self.num_heads = attention_heads self.attention_group_size = attention_group_size self.time_reduction = time_reduction self.nemo_conv_settings = nemo_conv_settings self.encoder_embedding_config = encoder_embedding_config if self.input_layer == "nemo_conv": default_nemo_conv_settings = { "subsampling": "dw_striding", "subsampling_factor": self.time_reduction, "feat_in": input_size, "feat_out": attention_dim, "conv_channels": 256, "subsampling_conv_chunking_factor": 1, "activation": nn.ReLU(), "is_causal": False, } # Override any of the defaults with the incoming, user settings if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: assert ( i not in nemo_conv_settings ), "{i} should be specified outside of the NeMo dictionary" self.embed = NemoConvSubsampling( **default_nemo_conv_settings, ) else: raise ValueError("unknown input_layer: " + input_layer) self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate) self.relative_attention_bias_type = ( relative_attention_bias_args.get("type") if relative_attention_bias_args else None ) if self.relative_attention_bias_type == "t5": assert ( self.num_heads % self.attention_group_size == 0 ), "attention_group_size must divide n_head" self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( self.num_heads // self.attention_group_size, max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000), symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), ) else: raise NotImplementedError def post_init(self, init_model_config): pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None) if pretrained_speech_encoder_path: model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu") encoder_state_dict = {} for k, v in model_state.items(): if "encoder." in k: tmp_k = k.replace("encoder.", "") encoder_state_dict[tmp_k] = v if hasattr(self, "encoder_embedding"): del self.encoder_embedding self.load_state_dict(encoder_state_dict) if not hasattr(self, "encoder_embedding"): self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"]) mean_file = init_model_config.get('mean_file', None) invstd_file = init_model_config.get('invstd_file', None) if mean_file is not None and invstd_file is not None: self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) def compute_lens_change(self, feature_lens): """feature_lens: int return updated feature lens. This used to return a different lambda function for each case that computed the right thing. That does not work within Torchscript. If you really need this to be faster, create nn.Module()-s for all the cases and return one of them. Torchscript does support that. """ if self.input_layer == "nemo_conv": # Handle the special causal case subsampling_causal_cond = self.nemo_conv_settings.get("subsampling", "dw_striding") in [ "dw_striding", "striding", "striding_conv1d", ] is_causal = self.nemo_conv_settings.get("is_causal", False) if is_causal and subsampling_causal_cond: lens_change = ( torch.ceil(feature_lens / self.time_reduction).long() if isinstance(feature_lens, Tensor) else math.ceil(feature_lens / self.time_reduction) ) feature_lens_remainder = feature_lens % self.time_reduction if isinstance(feature_lens, Tensor): lens_change[feature_lens_remainder != 1] += 1 elif feature_lens_remainder != 1: lens_change += 1 return lens_change ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod def forward(self): """Abstract forward method implementation.""" def _chunk_size_selection(self, chunk_size=None, left_chunk=None): """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: chunk_size = self.chunk_size if left_chunk is None: left_chunk = self.left_chunk if isinstance(chunk_size, list): # Variable chunk size during training chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,))) chunk_size_train_eff = chunk_size[chunk_size_index] if not isinstance(left_chunk, list): raise ValueError("Since chunk_size is a list, left_chunk must be a list") if len(left_chunk) != len(chunk_size): raise ValueError( "The length of left_chunk must be the same as length of chunk_size." ) left_chunk_train_eff = left_chunk[chunk_size_index] else: chunk_size_train_eff = chunk_size left_chunk_train_eff = left_chunk return chunk_size_train_eff, left_chunk_train_eff def _get_embed_class(self, embed): # pylint: disable=protected-access is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) embed_class = embed if is_embed_using_act_chkpt: embed_class = embed._checkpoint_wrapped_module if is_embed_fsdp_wrapped: embed_class = embed.module return embed_class def _forward_embeddings_core(self, input_tensor, masks): embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) return input_tensor, masks def _position_embedding(self, input_tensor): pos_k = None pos_v = None if self.relative_attention_bias_layer is None: input_tensor = self.pos_emb(input_tensor) # default to add abs sinusoid embedding return pos_k, pos_v def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( chunk_size, left_chunk ) # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) # avoid randomness when run evaluation or decoding if self.training and np.random.rand() > 0.5: # Either first or last chunk is not complete. # If only the last one is not complete, EOS is not effective chunk_start_idx = seq_len - chunk_start_idx chunk_start_idx = chunk_start_idx[::-1] chunk_start_idx = chunk_start_idx[:-1] chunk_start_idx = np.insert(chunk_start_idx, 0, 0) enc_streaming_mask = ( adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff) .unsqueeze(0) .expand([batch_size, -1, -1]) ) return enc_streaming_mask def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): """Forwarding the inputs through the top embedding layers Args: xs_pad: torch.Tensor input tensor masks: torch.Tensor input mask chunk_size_nc: (optional, default is None) chunk size for non-causal layers left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers """ # pylint: disable=R0915 # get new lens. seq_len = int(self.compute_lens_change(xs_pad.shape[1])) if seq_len <= 0: raise ValueError( f"""The squence length after time reduction is invalid: {seq_len}. Your input feature is too short. Consider filtering out the very short sentence from data loader""", ) batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask( seq_len, batch_size, self.chunk_size, self.left_chunk ) if xs_pad.is_cuda: enc_streaming_mask = enc_streaming_mask.cuda() xs_pad = xs_pad.cuda() input_tensor = xs_pad input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: hs_mask = masks & streaming_mask elif masks is not None: hs_mask = masks else: hs_mask = streaming_mask if chunk_size_nc is not None: enc_streaming_mask_nc = self._streaming_mask( seq_len, batch_size, chunk_size_nc, left_chunk_nc ) if xs_pad.is_cuda: enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() if masks is not None: hs_mask_nc = masks & enc_streaming_mask_nc else: hs_mask_nc = enc_streaming_mask_nc else: hs_mask_nc = None pos_k, pos_v = self._position_embedding(input_tensor) if chunk_size_nc is None: return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc def get_offset(self): """Returns offset used when retaining inputs for decoding. This is essentially, how many additional frames have to be added to the front-end CNN input to ensure it can produce a single output. So if the "padding" parameter is 0, typically offset will be > 0. """ return get_offset(self.input_layer, self.time_reduction) def get_offset(input_layer: str, time_reduction: int): """Get an offset. We will use the offset for determining #frames of a subsampled feature. Args: input_layer (str): Type of an input layer time_reduction (int): time reduction factor for downsampling a feature Returns: int: offset """ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: return 3 if input_layer in ("conv2d",) and time_reduction == 6: return 1 if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: return 7 return 0 class ConformerEncoder(TransformerEncoderBase): """ConformerEncoder module. see original paper for more details: https://arxiv.org/abs/2005.08100 Please set causal = True in streaming model Args: input_size: int input feature dimension. chunk_size: int, list(int) Number of frames for each chunk This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training Some examples for the 2 cases: chunk_size = 12 chunk_size = [6, 8, 12, 24] left_chunk: int, list(int) Number of chunks used for masking in streaming mode. This variable can take 2 forms: int: Used for inference, or single chunk size training list(int) : Used only for variable chunk size training. When chunk_size is a list, left_chunk must be a list with same length. Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] left_chunk: int number of chunks used for masking in streaming mode. num_lang: int This parameter is used to store the number of languages in the lang_dict, only used for multiseed/multilingual models. default None. attention_dim: int, optional attention dimension. default 256. attention_heads: int, optional the number of heads. default 4 linear_units: the number of units of position-wise feed forward. default 2048 num_block: number of Transformer layer. default 6 dropout_rate: float, optional dropout rate. default 0.1 input_layer: str, optional input layer type before Conformer, one of ["linear", "conv2d", "custom", "vgg2l", "embed"], default "conv2d" causal: bool, optional if set to True, convolution have no access to future frames. default False. batch_norm: bool, optional if set to True, apply batchnorm before activation in ConvModule layer of the conformer. default False cnn_out: int, optional the number of CNN channels before Conformer. default -1. cnn_layer_norm: bool, optional layer norm between Conformer and the first CNN. default False. ext_pw_out_channel: int, optional the number of channel for CNN before depthwise_seperable_CNN. If 0 then use linear. default 0. ext_pw_kernel_size: int, optional kernel size of N before depthwise_seperable_CNN. only work for ext_pw_out_channel > 0. default 1 depthwise_seperable_out_channel: int, optional the number of channel for depthwise_seperable_CNN. default 256. depthwise_multiplier: int, optional the number of multiplier for depthwise_seperable_CNN. default 1. chunk_se: int, optional 0 for offline SE. 1 for streaming SE, where mean is computed by accumulated history until current chunk_se. 2 for streaming SE, where mean is computed by only the current chunk. default 0. kernel_size: int, optional the number of kernels for depthwise_seperable_CNN. default 3. activation: str, optional FeedForward block activation. one of ["relu", "swish", "sigmoid"] default "relu". conv_activation: str, optional activation function used in ConvModule part of the conformer, default "relu". conv_glu_type: str, otional activation used use glu in depthwise_seperable_CNN, default "sigmoid" bias_in_glu: bool, optional if set to True, use additive bias in the weight module before GLU. default True linear_glu_in_convm: bool, optional if set to True, use GLULinear module, otherwise, used GLUPointWiseConv module. default to False. attention_glu_type: str only work for glu_in_attention !=0 default "swish". export: bool, optional if set to True, it remove the padding from convolutional layers and allow the onnx conversion for inference. default False. activation_checkpointing: str, optional a dictionarry of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. "interval": int, default 1, interval of applying activation checkpointing, interval = 1 means that we apply checkpointing on every layer (if activation), otherwise, we apply it every x interval. "offload": bool, default False, if set to True, we offload activation to cpu and reload it during backward, otherwise, we recalculate activation in backward. default "". extra_layer_output_idx: int the layer index to be exposed. relative_attention_bias_args: dict, optional use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} additional method-specific arguments can be provided (see transformer_base.py) time_reduction: int optional time reduction factor default 4 use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention in training. Default: False nemo_conv_settings: dict, optional A dictionary of settings for NeMo Subsampling. default: None usage: nemo_conv_settings= { "subsampling": dw_striding/striding/dw_striding_conv1d/striding_conv1d, "conv_channels": int, "subsampling_conv_chunking_factor": int, "is_causal": True/False } conv2d_extra_padding: str, optional Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True) Default: none replication_pad_for_subsample_embedding: For batched-streaming decoding, use "replication" padding for the cache at start of utterance. Default: False attention_group_size: int, optional the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attenion_heads = Multi-Query Attention """ extra_multi_layer_output_idxs: List[int] def __init__( # pylint: disable-all self, input_size, chunk_size, left_chunk, num_lang=None, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, input_layer="nemo_conv", causal=True, batch_norm=False, cnn_out=-1, cnn_layer_norm=False, ext_pw_out_channel=0, ext_pw_kernel_size=1, depthwise_seperable_out_channel=256, depthwise_multiplier=1, chunk_se=0, kernel_size=3, activation="relu", conv_activation="relu", conv_glu_type="sigmoid", bias_in_glu=True, linear_glu_in_convm=False, attention_glu_type="swish", export=False, extra_layer_output_idx=-1, extra_multi_layer_output_idxs=[], activation_checkpointing="", relative_attention_bias_args=None, time_reduction=4, use_pt_scaled_dot_product_attention=False, nemo_conv_settings=None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", replication_pad_for_subsample_embedding=False, attention_group_size=1, encoder_embedding_config=None, ): super().__init__( input_size, chunk_size, left_chunk, attention_dim, attention_heads, input_layer, cnn_out, cnn_layer_norm, time_reduction, dropout_rate=dropout_rate, relative_attention_bias_args=relative_attention_bias_args, positional_dropout_rate=0.0, nemo_conv_settings=nemo_conv_settings, conv2d_extra_padding=conv2d_extra_padding, attention_group_size=attention_group_size, encoder_embedding_config=encoder_embedding_config, ) self.num_blocks = num_blocks self.num_lang = num_lang self.kernel_size = kernel_size self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed) self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head" self.num_heads_k = self.num_heads // attention_group_size self.encoders = repeat( num_blocks, lambda i: encoder_checkpoint_wrapper( activation_checkpointing, ConformerEncoderLayer, i )( ConformerEncoderLayer( d_model=attention_dim, ext_pw_out_channel=ext_pw_out_channel, depthwise_seperable_out_channel=depthwise_seperable_out_channel, depthwise_multiplier=depthwise_multiplier, n_head=attention_heads, d_ffn=linear_units, ext_pw_kernel_size=ext_pw_kernel_size, kernel_size=kernel_size, dropout_rate=dropout_rate, causal=causal, batch_norm=batch_norm, activation=activation, chunk_se=chunk_se, chunk_size=chunk_size, conv_activation=conv_activation, conv_glu_type=conv_glu_type, bias_in_glu=bias_in_glu, linear_glu_in_convm=linear_glu_in_convm, attention_glu_type=attention_glu_type, activation_checkpointing=attn_checkpointing(activation_checkpointing, i), export=export, use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, attn_group_sizes=attention_group_size, ) ), ) self.extra_layer_output_idx = extra_layer_output_idx self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs # Make a zeros scalar we can use in get_initial_state to determine # the device and the needed dtype: self.register_buffer("dev_type", torch.zeros(()), persistent=False) def init_relative_attention_bias(self, input_tensor): if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) def calculate_hs_mask(self, xs_pad, device, mask): max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.chunk_size, self.left_chunk ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = ( torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) ) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask @torch.jit.ignore def forward(self, xs_pad, masks): """Conformer Forward function Args: xs_pad: torch.Tensor input tensor masks: torch.Tensor post-embedding input lengths """ xs_pad = self.encoder_embedding(xs_pad) input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks) unfolded = False ori_bz, seq_len, D = input_tensor.shape max_seq_len = 500 #maxium position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len unfolded = True # the unfold op will drop residual frames, pad it to the multiple of max_seq_len if seq_len % max_seq_len > 0: chunk_pad_size = max_seq_len - (seq_len % max_seq_len) else: chunk_pad_size = 0 if chunk_pad_size > 0: input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) input_tensor = input_tensor_pad.to(input_tensor.device) input_tensor = unfold_tensor(input_tensor, max_seq_len) if masks is not None: # revise hs_mask here because the previous calculated hs_mask did not consider extra pad subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor else: masks_unfold = None hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask layer_emb = None relative_attention_bias = self.init_relative_attention_bias(input_tensor) _simplified_path = ( self.extra_layer_output_idx == -1 and relative_attention_bias is None ) if _simplified_path: input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) else: for i, layer in enumerate(self.encoders): input_tensor, _, _, _ = layer( input_tensor, pos_k, pos_v, hs_mask, relative_attention_bias=relative_attention_bias, ) if i == self.extra_layer_output_idx: layer_emb = input_tensor if unfolded: embed_dim = input_tensor.shape[-1] input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) # if we ever padded before unfolding, we need to remove the padding if chunk_pad_size > 0: input_tensor = input_tensor[:, :-chunk_pad_size, :] return input_tensor, masks #, layer_emb def gradient_checkpointing_enable(self): pass