# OSUM / wenet / utils / class_utils.py
# NOTE(review): the lines above/below this file's header were non-code scrape
# residue (author "tomxxie", note "适配zeroGPU" = "adapt to ZeroGPU", commit
# 568e264); commented out so the module can actually be imported.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-28] <[email protected], Xingchen Song>
import torch
from torch.nn import BatchNorm1d, LayerNorm
from wenet.paraformer.embedding import ParaformerPositinoalEncoding
from wenet.transformer.norm import RMSNorm
from wenet.transformer.positionwise_feed_forward import (
GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
from wenet.transformer.swish import Swish, New_gelu4npu
from wenet.transformer.subsampling import (
LinearNoSubsampling,
EmbedinigNoSubsampling,
Conv1dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
StackNFramesSubsampling,
)
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
from wenet.transformer.embedding import (PositionalEncoding,
RelPositionalEncoding,
RopePositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from wenet.transformer.attention import (MultiHeadedAttention,
MultiHeadedCrossAttention,
RelPositionMultiHeadedAttention,
RopeMultiHeadedAttention,
ShawRelPositionMultiHeadedAttention)
from wenet.efficient_conformer.attention import (
GroupedRelPositionMultiHeadedAttention)
# Activation-function registry, keyed by the names used in model configs.
WENET_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # Prefer torch's native SiLU when this torch build provides it; otherwise
    # fall back to the project's own Swish implementation.
    "swish": torch.nn.SiLU if hasattr(torch.nn, "SiLU") else Swish,
    # Project-local GELU variant (see wenet.transformer.swish) used instead of
    # torch.nn.GELU — presumably for NPU compatibility; confirm with authors.
    "gelu": New_gelu4npu,
}
# Recurrent-cell registry, keyed by config name; every entry maps straight
# to the corresponding torch.nn module class.
WENET_RNN_CLASSES = dict(
    rnn=torch.nn.RNN,
    lstm=torch.nn.LSTM,
    gru=torch.nn.GRU,
)
# Subsampling front-end registry, keyed by config name.  The imported class
# names encode conv dimensionality and rate (e.g. Conv2dSubsampling4).
WENET_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    "conv2d2": Conv2dSubsampling2,
    # NOTE: plain "conv2d" selects the 4x subsampling variant.
    "conv2d": Conv2dSubsampling4,
    "dwconv2d4": DepthwiseConv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    # Pass-through placeholder used by the paraformer recipe.
    "paraformer_dummy": torch.nn.Identity,
    "stack_n_frames": StackNFramesSubsampling,
}
# Positional-encoding registry, keyed by config name.
WENET_EMB_CLASSES = {
    "embed": PositionalEncoding,
    # "abs_pos" is an alias of "embed": both resolve to PositionalEncoding.
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
    # Spelling of the imported class ("Positinoal") comes from the paraformer
    # module itself; do not "fix" it here.
    "abs_pos_paraformer": ParaformerPositinoalEncoding,
    "rope_pos": RopePositionalEncoding,
}
# Attention-mechanism registry, keyed by config name.
WENET_ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
    "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
    "crossattn": MultiHeadedCrossAttention,
    "shaw_rel_selfattn": ShawRelPositionMultiHeadedAttention,
    "rope_abs_selfattn": RopeMultiHeadedAttention,
}
# Feed-forward (MLP) layer registry, keyed by config name.
WENET_MLP_CLASSES = {
    "position_wise_feed_forward": PositionwiseFeedForward,
    # Mixture-of-experts feed-forward layer.
    "moe": MoEFFNLayer,
    # Gated MLP variant.
    "gated": GatedVariantsMLP,
}
# Normalization-layer registry, keyed by config name.  layer/batch norm come
# from torch.nn; rms_norm is the project-local implementation.
WENET_NORM_CLASSES = {
    "layer_norm": LayerNorm,
    "batch_norm": BatchNorm1d,
    "rms_norm": RMSNorm,
}