# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
# 2023 Voicecomm Inc (Kai Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""

from typing import List, Optional, Union

import torch

from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.class_utils import WENET_ATTENTION_CLASSES

class BranchformerEncoder(BaseEncoder):
"""Branchformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
use_attn: bool = True,
attention_heads: int = 4,
selfattention_layer_type: str = "rel_selfattn",
pos_enc_layer_type: str = "rel_pos",
use_cgmlp: bool = True,
cgmlp_linear_units: int = 2048,
cgmlp_conv_kernel: int = 31,
use_linear_after_conv: bool = False,
gate_activation: str = "identity",
merge_method: str = "concat",
cgmlp_weight: Union[float, List[float]] = 0.5,
attn_branch_drop_rate: Union[float, List[float]] = 0.0,
num_blocks: int = 12,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
stochastic_depth_rate: Union[float, List[float]] = 0.0,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
use_dynamic_left_chunk: bool = False,
causal: bool = False,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
):
        # normalize_before is hard-coded to True (pre-norm) for Branchformer.
        super().__init__(input_size, output_size, attention_heads,
                         cgmlp_linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, True,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps)
        # Positional args for the selected self-attention class.
        encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
)
        cgmlp_layer = ConvolutionalGatingMLP
        # Positional args for ConvolutionalGatingMLP.
        cgmlp_layer_args = (
output_size,
cgmlp_linear_units,
cgmlp_conv_kernel,
dropout_rate,
use_linear_after_conv,
gate_activation,
causal,
)
        # Broadcast scalar hyperparameters to per-block lists.
        if isinstance(stochastic_depth_rate, float):
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
if len(stochastic_depth_rate) != num_blocks:
raise ValueError(
f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
f"should be equal to num_blocks ({num_blocks})")
if isinstance(cgmlp_weight, float):
cgmlp_weight = [cgmlp_weight] * num_blocks
if len(cgmlp_weight) != num_blocks:
raise ValueError(
f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
f"num_blocks ({num_blocks})")
if isinstance(attn_branch_drop_rate, float):
attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
if len(attn_branch_drop_rate) != num_blocks:
raise ValueError(
f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
f"should be equal to num_blocks ({num_blocks})")
self.encoders = LayerDropModuleList(
p=stochastic_depth_rate,
modules=[
BranchformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate,
merge_method,
cgmlp_weight[lnum],
attn_branch_drop_rate[lnum],
stochastic_depth_rate[lnum],
) for lnum in range(num_blocks)
])
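
# A hedged construction sketch (not part of the original file): the per-layer
# hyperparameters above (cgmlp_weight, attn_branch_drop_rate,
# stochastic_depth_rate) accept either a single float, which __init__
# broadcasts to every block, or an explicit list with one value per block:
#
#     encoder = BranchformerEncoder(
#         input_size=80,
#         num_blocks=12,
#         cgmlp_weight=0.5,  # same weight for every block
#         stochastic_depth_rate=[0.1 * i / 11 for i in range(12)],  # ramp
#     )
#
# A runnable demo along these lines is sketched at the bottom of this file.
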
# Modified from: https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py  # noqa
class LayerDropModuleList(torch.nn.ModuleList):
"""
A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
We refresh the choice of which layers to drop every time we iterate
over the LayerDropModuleList instance. During evaluation we always
iterate over all layers.
Usage::
layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
for layer in layers: # this might iterate over layers 1 and 3
x = layer(x)
for layer in layers: # this might iterate over all layers
x = layer(x)
for layer in layers: # this might not iterate over any layers
x = layer(x)
Args:
p (float): probability of dropping out each layer
modules (iterable, optional): an iterable of modules to add
Limitations:
1 can work with ddp when layer's gradient checkpoint disabled
2 can't work with ddp when layer's gradient checkpoint enables
3 can work with fsdp
4 can work with deepspeed
"""
def __init__(self, p: List[float], modules=None):
super().__init__(modules)
assert len(p) == len(self)
self.p = p

    def __iter__(self):
        # Draw one uniform sample per layer; a layer is kept on this pass
        # if its sample exceeds its drop probability. In eval mode every
        # layer is always yielded.
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p[i]):
                yield m
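

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module. It assumes the
    # standard WeNet BaseEncoder interface, i.e. forward(xs, xs_lens)
    # returning (encoded, mask); all shapes below are illustrative.
    torch.manual_seed(0)

    encoder = BranchformerEncoder(
        input_size=80,  # e.g. 80-dim log-mel filterbank features
        output_size=256,
        num_blocks=12,
        # Linearly increasing LayerDrop probability over depth.
        stochastic_depth_rate=[0.1 * i / 11 for i in range(12)],
    )

    xs = torch.randn(2, 100, 80)  # (batch, time, feature)
    xs_lens = torch.tensor([100, 80])  # valid frames per utterance

    # In eval mode LayerDrop is disabled, so all 12 blocks run.
    encoder.eval()
    with torch.no_grad():
        out, mask = encoder(xs, xs_lens)
    # The conv2d front-end subsamples the time axis roughly 4x.
    print(out.shape)  # e.g. torch.Size([2, 24, 256])

    # In train mode, iterating self.encoders may skip blocks at random.
    encoder.train()
    n_active = sum(1 for _ in encoder.encoders)
    print(f"{n_active}/12 blocks active this pass")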