from torch import nn

from .espnet_positional_embedding import RelPositionalEncoding
from .espnet_transformer_attn import RelPositionMultiHeadedAttention
from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
from ..layers import Embedding


class ConformerLayers(nn.Module):
    def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
                 use_last_norm=True, save_hidden=False):
        super().__init__()
        self.use_last_norm = use_last_norm
        positionwise_layer = MultiLayeredConv1d
        # Position-wise feed-forward args: (in_chans, hidden_chans, kernel_size,
        # dropout), i.e. 1x1 convolutions with a 4x hidden expansion.
        positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
        self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
        # Macaron-style Conformer blocks: feed-forward -> relative-position
        # self-attention -> convolution -> feed-forward.
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            hidden_size,
            RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args),
            ConvolutionModule(hidden_size, kernel_size, Swish()),
            dropout,
        ) for _ in range(num_layers)])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.Linear(hidden_size, hidden_size)
        self.save_hidden = save_hidden
        if save_hidden:
            self.hiddens = []

    def forward(self, x, padding_mask=None):
        """

        :param x: [B, T, H]
        :param padding_mask: [B, T] (unused; padding is inferred from all-zero frames in x)
        :return: [B, T, H]
        """
        if self.save_hidden:
            self.hiddens = []
        # Treat all-zero feature vectors as padding.
        nonpadding_mask = x.abs().sum(-1) > 0
        # RelPositionalEncoding returns an (x, pos_emb) tuple, which each
        # EncoderLayer consumes and re-emits.
        x = self.pos_embed(x)
        for layer in self.encoder_layers:
            x, _ = layer(x, nonpadding_mask[:, None, :])
            if self.save_hidden:
                self.hiddens.append(x[0])
        x = x[0]  # drop pos_emb from the (x, pos_emb) tuple
        # Final norm (or linear projection), re-zeroing padded steps.
        x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
        return x
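

# Illustrative usage of ConformerLayers (a sketch, not part of the original
# module; the sizes below are arbitrary and `torch` must be imported first):
#
#     layers = ConformerLayers(hidden_size=256, num_layers=2)
#     x = torch.randn(2, 100, 256)   # [B, T, H]; all-zero frames count as padding
#     out = layers(x)                # -> [2, 100, 256]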


class ConformerEncoder(ConformerLayers):
    def __init__(self, hidden_size, dict_size, num_layers):
        conformer_enc_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
        self.embed = Embedding(dict_size, hidden_size, padding_idx=0)

    def forward(self, x):
        """

        :param x: [B, T] token ids (0 is the padding index)
        :return: [B, T, H]
        """
        x = self.embed(x)
        x = super().forward(x)
        return x


class ConformerDecoder(ConformerLayers):
    def __init__(self, hidden_size, num_layers):
        conformer_dec_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
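

# A minimal smoke test (an illustrative sketch, not part of the original API;
# the sizes are arbitrary). Because this module uses relative imports, run it
# as a package module, e.g. `python -m <package>.conformer`.
if __name__ == "__main__":
    import torch

    enc = ConformerEncoder(hidden_size=256, dict_size=100, num_layers=2)
    tokens = torch.randint(1, 100, (2, 50))
    tokens[1, 30:] = 0  # right-pad the second sequence with the padding index
    h = enc(tokens)
    print(h.shape)  # torch.Size([2, 50, 256])

    dec = ConformerDecoder(hidden_size=256, num_layers=2)
    y = dec(h)
    print(y.shape)  # torch.Size([2, 50, 256])
    # Padded steps stay zeroed through both stacks, since padding is inferred
    # from all-zero frames.
    assert y[1, 30:].abs().sum().item() == 0.0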