# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2023 NetEase Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""

from typing import Optional, Tuple

import torch

from wenet.utils.mask import make_pad_mask
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder


class DualTransformerEncoder(TransformerEncoder):
    """Transformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        selfattention_layer_type: str = "selfattn",
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """Construct a DualTransformerEncoder.

        Supports the full-context mode and the streaming mode separately.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, query_bias, key_bias,
                         value_bias, activation_type, gradient_checkpointing,
                         use_sdpa, layer_norm_type, norm_eps, n_kv_head,
                         head_dim, selfattention_layer_type, mlp_type,
                         mlp_bias, n_expert, n_expert_activated)

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input frames in full-context (non-streaming) mode.

        Unlike ``forward``, no chunk mask is applied, so every frame
        attends to the whole unpadded utterance. The two chunk arguments
        are unused and kept only for interface compatibility.
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        for layer in self.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks


class DualConformerEncoder(ConformerEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: Optional[torch.nn.Module] = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        conv_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """Construct a DualConformerEncoder.

        Supports the full-context mode and the streaming mode separately.
        """
        super().__init__(
            input_size, output_size, attention_heads, linear_units,
            num_blocks, dropout_rate, positional_dropout_rate,
            attention_dropout_rate, input_layer, pos_enc_layer_type,
            normalize_before, static_chunk_size, use_dynamic_chunk,
            global_cmvn, use_dynamic_left_chunk,
            positionwise_conv_kernel_size, macaron_style,
            selfattention_layer_type, activation_type, use_cnn_module,
            cnn_module_kernel, causal, cnn_module_norm, query_bias,
            key_bias, value_bias, conv_bias, gradient_checkpointing,
            use_sdpa, layer_norm_type, norm_eps, n_kv_head, head_dim,
            mlp_type, mlp_bias, n_expert, n_expert_activated)

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input frames in full-context (non-streaming) mode.

        See ``DualTransformerEncoder.forward_full``.
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        for layer in self.encoders:
            xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks
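

# Minimal usage sketch (illustrative, not part of the WeNet API): it
# contrasts the chunk-masked streaming pass (``forward``, inherited from
# the base encoder) with the full-context pass (``forward_full``) defined
# above. The 80-dim feature input and all shapes/hyperparameters below
# are assumptions chosen for the demo, not values prescribed by this
# module.
if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = DualTransformerEncoder(input_size=80,
                                     output_size=256,
                                     num_blocks=2,
                                     use_dynamic_chunk=True)
    encoder.eval()
    xs = torch.randn(2, 100, 80)       # (batch, time, feature)
    xs_lens = torch.tensor([100, 80])  # valid frames per utterance
    with torch.no_grad():
        # Streaming-style pass: attention is restricted to chunks of 4
        # subsampled frames (plus left context).
        ys_stream, _ = encoder(xs, xs_lens, decoding_chunk_size=4)
        # Full-context pass: only the padding mask is applied, so every
        # frame attends to the entire unpadded utterance.
        ys_full, _ = encoder.forward_full(xs, xs_lens)
    # Same output shape (B, T/subsample_rate, output_size); the values
    # differ because the attention masks differ.
    print(ys_stream.shape, ys_full.shape)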