# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
#               2023 Voicecomm Inc (Kai Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet (https://github.com/espnet/espnet)
"""Encoder definition."""

from typing import List, Optional, Union

import torch

from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.class_utils import (
    WENET_ATTENTION_CLASSES,
)


class BranchformerEncoder(BaseEncoder):
    """Branchformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        selfattention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
        causal: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__(input_size, output_size, attention_heads,
                         cgmlp_linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, True,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps)

        # Arguments for the attention branch of each Branchformer layer.
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            query_bias,
            key_bias,
            value_bias,
            use_sdpa,
            n_kv_head,
            head_dim,
        )

        # Arguments for the convolutional gating MLP (cgMLP) branch.
        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
            causal,
        )

        # Broadcast scalar per-layer rates to lists and validate their lengths.
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})")

        if isinstance(cgmlp_weight, float):
            cgmlp_weight = [cgmlp_weight] * num_blocks
        if len(cgmlp_weight) != num_blocks:
            raise ValueError(
                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
                f"num_blocks ({num_blocks})")

        if isinstance(attn_branch_drop_rate, float):
            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
        if len(attn_branch_drop_rate) != num_blocks:
            raise ValueError(
f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) " f"should be equal to num_blocks ({num_blocks})") self.encoders = LayerDropModuleList( p=stochastic_depth_rate, modules=[ BranchformerEncoderLayer( output_size, WENET_ATTENTION_CLASSES[selfattention_layer_type]( *encoder_selfattn_layer_args) if use_attn else None, cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, dropout_rate, merge_method, cgmlp_weight[lnum], attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum], ) for lnum in range(num_blocks) ]) # modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa class LayerDropModuleList(torch.nn.ModuleList): """ A LayerDrop implementation based on :class:`torch.nn.ModuleList`. We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During evaluation we always iterate over all layers. Usage:: layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) for layer in layers: # this might iterate over layers 1 and 3 x = layer(x) for layer in layers: # this might iterate over all layers x = layer(x) for layer in layers: # this might not iterate over any layers x = layer(x) Args: p (float): probability of dropping out each layer modules (iterable, optional): an iterable of modules to add Limitations: 1 can work with ddp when layer's gradient checkpoint disabled 2 can't work with ddp when layer's gradient checkpoint enables 3 can work with fsdp 4 can work with deepspeed """ def __init__(self, p: List[float], modules=None): super().__init__(modules) assert len(p) == len(self) self.p = p def __iter__(self): dropout_probs = torch.empty(len(self)).uniform_() for i, m in enumerate(super().__iter__()): if not self.training or (dropout_probs[i] > self.p[i]): yield m