# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
#               2023 Voicecomm Inc (Kai Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition.""" | |
import torch | |
from typing import List, Optional, Union | |
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer | |
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP | |
from wenet.transformer.encoder import BaseEncoder | |
from wenet.utils.class_utils import ( | |
WENET_ATTENTION_CLASSES, ) | |


class BranchformerEncoder(BaseEncoder):
    """Branchformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        selfattention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        causal: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        super().__init__(input_size, output_size, attention_heads,
                         cgmlp_linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, True,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps)
        # Arguments for the attention branch of every encoder layer.
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            query_bias,
            key_bias,
            value_bias,
            use_sdpa,
            n_kv_head,
            head_dim,
        )

        # Arguments for the convolutional gating MLP (cgMLP) branch.
        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
            causal,
        )
        # Scalar hyperparameters are broadcast to one value per block; list
        # values must already contain exactly num_blocks entries.
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})")

        if isinstance(cgmlp_weight, float):
            cgmlp_weight = [cgmlp_weight] * num_blocks
        if len(cgmlp_weight) != num_blocks:
            raise ValueError(
                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
                f"num_blocks ({num_blocks})")

        if isinstance(attn_branch_drop_rate, float):
            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
        if len(attn_branch_drop_rate) != num_blocks:
            raise ValueError(
                f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
                f"should be equal to num_blocks ({num_blocks})")
        # Stack the encoder layers; LayerDropModuleList applies stochastic
        # depth by randomly skipping whole layers during training.
        self.encoders = LayerDropModuleList(
            p=stochastic_depth_rate,
            modules=[
                BranchformerEncoderLayer(
                    output_size,
                    WENET_ATTENTION_CLASSES[selfattention_layer_type](
                        *encoder_selfattn_layer_args) if use_attn else None,
                    cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                    dropout_rate,
                    merge_method,
                    cgmlp_weight[lnum],
                    attn_branch_drop_rate[lnum],
                    stochastic_depth_rate[lnum],
                ) for lnum in range(num_blocks)
            ])
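

# Illustrative only (not part of the original file): a minimal sketch of how
# this encoder is typically constructed and called. It assumes 80-dim
# fbank-style features and the (batch, time, feat) forward convention
# inherited from BaseEncoder; treat names and shapes as indicative.
def _example_branchformer_usage():  # pragma: no cover
    encoder = BranchformerEncoder(input_size=80, output_size=256, num_blocks=2)
    xs = torch.randn(4, 100, 80)                # (batch, time, input_size)
    xs_lens = torch.tensor([100, 80, 60, 40])   # valid frames per utterance
    out, masks = encoder(xs, xs_lens)           # inherited BaseEncoder.forward
    # out: (batch, subsampled_time, output_size); masks: padding mask
    return out, masks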


# Modified from: https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py  # noqa
class LayerDropModuleList(torch.nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    The choice of which layers to drop is refreshed every time we iterate
    over the LayerDropModuleList instance. During evaluation all layers are
    always iterated over.

    Usage::

        layers = LayerDropModuleList(p=[0.5, 0.5, 0.5],
                                     modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (List[float]): per-layer probability of dropping that layer
        modules (iterable, optional): an iterable of modules to add

    Limitations:
        1. works with DDP when gradient checkpointing is disabled for the layers
        2. does not work with DDP when gradient checkpointing is enabled
        3. works with FSDP
        4. works with DeepSpeed
    """

    def __init__(self, p: List[float], modules=None):
        super().__init__(modules)
        # One drop probability per registered module.
        assert len(p) == len(self)
        self.p = p

    def __iter__(self):
        # Draw a fresh uniform sample per layer on every pass; a layer is
        # skipped only in training mode, when its sample does not exceed p[i].
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p[i]):
                yield m
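

# Illustrative only (not part of the original file): a tiny demonstration of
# the LayerDrop behaviour above using plain nn.Linear blocks. In training
# mode each pass re-samples which layers are skipped; in eval mode every
# layer is always yielded.
if __name__ == "__main__":
    blocks = [torch.nn.Linear(8, 8) for _ in range(4)]
    layers = LayerDropModuleList(p=[0.5, 0.5, 0.5, 0.5], modules=blocks)

    layers.train()
    kept_in_train = sum(1 for _ in layers)  # random, anywhere from 0 to 4

    layers.eval()
    kept_in_eval = sum(1 for _ in layers)   # always 4

    print(f"kept {kept_in_train}/4 layers in train mode, "
          f"{kept_in_eval}/4 in eval mode")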