#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-28] <[email protected], Xingchen Song>
import torch
from torch.nn import BatchNorm1d, LayerNorm
from wenet.paraformer.embedding import ParaformerPositinoalEncoding
from wenet.transformer.norm import RMSNorm
from wenet.transformer.positionwise_feed_forward import (
    GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
from wenet.transformer.swish import Swish, New_gelu4npu
from wenet.transformer.subsampling import (
    LinearNoSubsampling,
    EmbedinigNoSubsampling,
    Conv1dSubsampling2,
    Conv2dSubsampling4,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    StackNFramesSubsampling,
)
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
from wenet.transformer.embedding import (PositionalEncoding,
                                         RelPositionalEncoding,
                                         RopePositionalEncoding,
                                         WhisperPositionalEncoding,
                                         LearnablePositionalEncoding,
                                         NoPositionalEncoding)
from wenet.transformer.attention import (MultiHeadedAttention,
                                         MultiHeadedCrossAttention,
                                         RelPositionMultiHeadedAttention,
                                         RopeMultiHeadedAttention,
                                         ShawRelPositionMultiHeadedAttention)
from wenet.efficient_conformer.attention import (
    GroupedRelPositionMultiHeadedAttention)
# Registry: activation-name string (from a model config) -> activation class.
WENET_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # Use the native torch.nn.SiLU when this torch version provides it,
    # otherwise fall back to wenet's own Swish implementation.
    "swish": getattr(torch.nn, "SiLU", Swish),
    # NOTE(review): maps "gelu" to the NPU-oriented New_gelu4npu rather than
    # torch.nn.GELU — presumably intentional for NPU support; confirm.
    "gelu": New_gelu4npu,
}
# Registry: recurrent-layer name string -> torch recurrent module class.
WENET_RNN_CLASSES = dict(
    rnn=torch.nn.RNN,
    lstm=torch.nn.LSTM,
    gru=torch.nn.GRU,
)
# Registry: subsampling-layer name string -> subsampling module class.
# Trailing digits in a key name the subsampling factor (e.g. "conv2d6" = 6x).
WENET_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    "conv2d2": Conv2dSubsampling2,
    # Plain "conv2d" selects the 4x variant.
    "conv2d": Conv2dSubsampling4,
    "dwconv2d4": DepthwiseConv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    # Identity pass-through; presumably paraformer subsamples elsewhere — verify.
    'paraformer_dummy': torch.nn.Identity,
    'stack_n_frames': StackNFramesSubsampling,
}
# Registry: positional-embedding name string -> positional-encoding class.
WENET_EMB_CLASSES = {
    # "embed" and "abs_pos" are aliases for the same absolute encoding.
    "embed": PositionalEncoding,
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
    # NOTE(review): "Positinoal" is a typo in the imported class name itself;
    # it must stay misspelled here until fixed at its definition site.
    "abs_pos_paraformer": ParaformerPositinoalEncoding,
    'rope_pos': RopePositionalEncoding,
}
# Registry: attention-type name string -> attention module class.
WENET_ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
    "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
    "crossattn": MultiHeadedCrossAttention,
    'shaw_rel_selfattn': ShawRelPositionMultiHeadedAttention,
    'rope_abs_selfattn': RopeMultiHeadedAttention,
}
# Registry: feed-forward-type name string -> MLP/FFN module class.
WENET_MLP_CLASSES = {
    'position_wise_feed_forward': PositionwiseFeedForward,
    # Mixture-of-experts feed-forward layer.
    'moe': MoEFFNLayer,
    # Gated variant (e.g. GLU-style) feed-forward.
    'gated': GatedVariantsMLP,
}
# Registry: normalization-type name string -> normalization module class.
WENET_NORM_CLASSES = {
    'layer_norm': LayerNorm,
    'batch_norm': BatchNorm1d,
    # wenet's own RMSNorm implementation (not a torch.nn builtin here).
    'rms_norm': RMSNorm,
}