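# Source page header (kept as comments so the module parses):
# OSUM / wenet/branchformer/encoder.py
# Last commit by tomxxie: 568e264 "适配zeroGPU" (i.e. "adapt for ZeroGPU")
# [raw | history | blame] 6.78 kB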
# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University)
# 2023 Voicecomm Inc (Kai Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
import torch
from typing import List, Optional, Union
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.branchformer.cgmlp import ConvolutionalGatingMLP
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.class_utils import (
WENET_ATTENTION_CLASSES, )
class BranchformerEncoder(BaseEncoder):
    """Branchformer encoder module.

    Every encoder block is a ``BranchformerEncoderLayer`` holding an optional
    self-attention branch (enabled by ``use_attn``) and an optional
    convolutional gating MLP branch (enabled by ``use_cgmlp``). The blocks
    are stored in a ``LayerDropModuleList`` whose per-block drop
    probabilities are ``stochastic_depth_rate``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        selfattention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
        causal: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
    ):
        """Construct a BranchformerEncoder.

        Args:
            input_size: input feature dimension (forwarded to BaseEncoder).
            output_size: model dimension used by the attention and cgMLP
                branches.
            use_attn: build a self-attention branch in every block.
            attention_heads: number of attention heads.
            selfattention_layer_type: key into ``WENET_ATTENTION_CLASSES``
                selecting the attention implementation.
            pos_enc_layer_type: positional encoding type (BaseEncoder).
            use_cgmlp: build a ConvolutionalGatingMLP branch in every block.
            cgmlp_linear_units: hidden size of the cgMLP branch.
            cgmlp_conv_kernel: convolution kernel size of the cgMLP branch.
            use_linear_after_conv: forwarded to ConvolutionalGatingMLP.
            gate_activation: forwarded to ConvolutionalGatingMLP.
            merge_method: branch merge method, forwarded to every
                BranchformerEncoderLayer.
            cgmlp_weight: scalar applied to every block, or one value per
                block; forwarded to BranchformerEncoderLayer.
            attn_branch_drop_rate: scalar applied to every block, or one
                value per block; forwarded to BranchformerEncoderLayer.
            num_blocks: number of encoder blocks.
            dropout_rate: dropout used by BaseEncoder, the cgMLP branch and
                every block.
            positional_dropout_rate: dropout of the positional encoding.
            attention_dropout_rate: dropout inside the attention module.
            input_layer: subsampling/input layer type (BaseEncoder).
            stochastic_depth_rate: scalar applied to every block, or one
                value per block. Used as the LayerDrop probability of each
                block during training and also forwarded to each layer.
            static_chunk_size, use_dynamic_chunk, use_dynamic_left_chunk,
            global_cmvn, gradient_checkpointing, layer_norm_type, norm_eps:
                forwarded to BaseEncoder.
            causal: forwarded to the cgMLP branch only.
            query_bias, key_bias, value_bias, n_kv_head, head_dim:
                forwarded to the self-attention module.
            use_sdpa: forwarded to BaseEncoder and the attention module.

        Raises:
            ValueError: if a per-block list (``stochastic_depth_rate``,
                ``cgmlp_weight`` or ``attn_branch_drop_rate``) does not have
                exactly ``num_blocks`` entries.
        """
        # `True` is passed positionally after pos_enc_layer_type; presumably
        # BaseEncoder's normalize_before flag -- verify against its signature.
        super().__init__(input_size, output_size, attention_heads,
                         cgmlp_linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, True,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps)

        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            query_bias,
            key_bias,
            value_bias,
            use_sdpa,
            n_kv_head,
            head_dim,
        )

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
            causal,
        )

        # Broadcast scalars / validate per-block lists (same order and
        # messages as before; ints such as `0` from YAML are now accepted).
        stochastic_depth_rate = self._expand_per_block(
            stochastic_depth_rate, "stochastic_depth_rate", num_blocks)
        cgmlp_weight = self._expand_per_block(cgmlp_weight, "cgmlp_weight",
                                              num_blocks)
        attn_branch_drop_rate = self._expand_per_block(
            attn_branch_drop_rate, "attn_branch_drop_rate", num_blocks)

        # Each block gets its own attention / cgMLP instance (built inside
        # the comprehension), so no parameters are shared across blocks.
        self.encoders = LayerDropModuleList(
            p=stochastic_depth_rate,
            modules=[
                BranchformerEncoderLayer(
                    output_size,
                    WENET_ATTENTION_CLASSES[selfattention_layer_type](
                        *encoder_selfattn_layer_args) if use_attn else None,
                    cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                    dropout_rate,
                    merge_method,
                    cgmlp_weight[lnum],
                    attn_branch_drop_rate[lnum],
                    stochastic_depth_rate[lnum],
                ) for lnum in range(num_blocks)
            ])

    @staticmethod
    def _expand_per_block(value, name: str, num_blocks: int) -> List[float]:
        """Return ``value`` as a list with exactly ``num_blocks`` entries.

        A scalar (``int`` or ``float``) is repeated ``num_blocks`` times;
        any other iterable is converted to a list and its length checked.

        Raises:
            ValueError: if the iterable's length differs from ``num_blocks``.
        """
        if isinstance(value, (int, float)):
            return [float(value)] * num_blocks
        values = list(value)
        if len(values) != num_blocks:
            raise ValueError(
                f"Length of {name} ({len(values)}) "
                f"should be equal to num_blocks ({num_blocks})")
        return values
# modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa
class LayerDropModuleList(torch.nn.ModuleList):
    """
    A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    The choice of which layers to drop is refreshed every time the list is
    iterated: in training mode layer ``i`` is skipped with probability
    ``p[i]``. In evaluation mode (``self.training`` is False) every layer is
    yielded and no random numbers are drawn. Indexing (``layers[i]``) and
    ``len(layers)`` are unaffected; only iteration applies LayerDrop.

    Usage::

        layers = LayerDropModuleList(p=[0.5, 0.5, 0.5],
                                     modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float or List[float]): per-layer drop probability. A single
            number is applied to every layer.
        modules (iterable, optional): an iterable of modules to add.

    Raises:
        ValueError: if ``p`` is a sequence whose length differs from the
            number of modules.

    Limitations:
        1. Works with DDP when the layers' gradient checkpointing is disabled.
        2. Does not work with DDP when the layers' gradient checkpointing is
           enabled.
        3. Works with FSDP.
        4. Works with DeepSpeed.
    """

    def __init__(self, p: Union[float, List[float]], modules=None):
        super().__init__(modules)
        if isinstance(p, (int, float)):
            # Broadcast a scalar probability to every layer.
            p = [float(p)] * len(self)
        else:
            p = list(p)
        if len(p) != len(self):
            raise ValueError(
                f"Length of p ({len(p)}) should be equal to the number of "
                f"modules ({len(self)})")
        self.p = p

    def __iter__(self):
        layers = super().__iter__()
        if not self.training:
            # No LayerDrop at eval time; skipping the random draw keeps
            # evaluation from advancing the global RNG state.
            yield from layers
            return
        # One uniform sample in [0, 1) per layer, drawn in a single call.
        samples = torch.empty(len(self)).uniform_().tolist()
        for i, layer in enumerate(layers):
            # Keep iff u >= p so that P(drop) == p exactly; with a strict
            # `>` a sample of exactly 0.0 would drop a layer whose p is 0.
            if samples[i] >= self.p[i]:
                yield layer