import math

import torch
import torch.nn as nn
class SpeechEmbedding(nn.Module):
"""
"""
    def __init__(self, args, _):  # the second argument is unused here
        super().__init__()
self.conv = Conv1dModule(args)
self.sinusoidalpos = False
self.emb_size = args.emb_size
if "sinusoidalpos" in args.embedding:
self.sinusoidalpos = True
    def forward(self, src, _):
        """Embed inputs.
        Args:
            src (FloatTensor): audio features of shape
                ``(batch_size, seq_len, audio_feature_size)``,
                or a raw waveform of shape ``(batch_size, seq_len)``
        """
speech_emb = self.conv(src)
        if self.sinusoidalpos:
            # Match the scale of the sinusoidal positional encodings
            # (the sqrt(d_model) convention from the Transformer).
            return speech_emb * math.sqrt(self.emb_size)
else:
return speech_emb
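# A minimal usage sketch (assumptions: `args` carries the fields read above --
# emb_size, embedding, data_processor, audio_feature_size, conv_channels,
# conv_kernel_sizes -- and the unused second arguments may be None):
#
#   emb = SpeechEmbedding(args, None)
#   out = emb(features, None)  # (B, T, audio_feature_size) -> (B, T', C)
#
# where T' is T reduced by a factor of 2 per conv layer (all strides are 2).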
class Transpose_module(nn.Module):
    """Swap the last two dimensions, e.g. (B, C, T) <-> (B, T, C).

    Used around nn.LayerNorm, which normalizes over the last dimension,
    whereas nn.Conv1d keeps channels in the second-to-last dimension.
    """
    def forward(self, x):
        return x.transpose(-2, -1)
class Conv1dModule(nn.Module):
"""
Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
via gated linear units (https://arxiv.org/abs/1911.08460)
"""
    def __init__(self, args):
        super().__init__()
self.embedding_dim = args.emb_size
self.norm_mode = None
self.feature_grad_mult = 1.0
self.conv_bias = True
self.dropout_input = 0.0
        self.use_glu = args.data_processor == "s2t"
self.padding = True
self.conv_channels = args.conv_channels
self.audio_feature_size = args.audio_feature_size
self.kernel_sizes = args.conv_kernel_sizes
self.strides = [2 for _ in range(len(self.kernel_sizes))]
self.conv_layers = nn.ModuleList()
def conv_layer_block(
in_channels,
out_channels,
kernel_size,
stride,
padding,
norm_mode=None,
conv_bias=False,
):
def make_conv(in_channels, out_channels, kernel_size, stride, padding, conv_bias):
conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=conv_bias)
nn.init.kaiming_normal_(conv.weight)
return conv
if norm_mode == "layer":
return nn.Sequential(
make_conv(in_channels, out_channels, kernel_size, stride, padding, conv_bias),
Transpose_module(),
nn.LayerNorm(out_channels, eps=1e-5, elementwise_affine=True),
Transpose_module(),
nn.GELU(),
)
elif norm_mode == "group":
return nn.Sequential(
make_conv(in_channels, out_channels, kernel_size, stride, padding, conv_bias),
nn.GroupNorm(out_channels, out_channels, eps=1e-5, affine=True),
nn.GELU(),
)
            # GLU is applied in forward() after each conv layer, so no
            # activation is attached here.
            elif self.use_glu:
                return nn.Sequential(
                    make_conv(in_channels, out_channels, kernel_size, stride, padding, conv_bias),
                )
else:
return nn.Sequential(
make_conv(in_channels, out_channels, kernel_size, stride, padding, conv_bias),
nn.GELU(),
)
        assert len(self.strides) == len(self.kernel_sizes), "strides and kernel_sizes must have the same length"
        assert len(self.strides) == len(self.conv_channels), "strides and conv_channels must have the same length"
        # Track each layer's input width. GLU halves the channel dimension in
        # forward(), so with GLU the next layer sees c // 2 channels.
        in_channel = self.audio_feature_size
        for i, (k, s, c) in enumerate(zip(self.kernel_sizes, self.strides, self.conv_channels)):
            # Group norm is only applied to the first layer (wav2vec 2.0 style).
            if self.norm_mode == "group" and i != 0:
                self.norm_mode = None
            padding = k // 2 if self.padding else 0
            self.conv_layers.append(
                conv_layer_block(
                    in_channel,
                    c,
                    k,
                    s,
                    padding,
                    norm_mode=self.norm_mode,
                    conv_bias=self.conv_bias,
                )
            )
            in_channel = c // 2 if self.use_glu else c
    def forward(self, input_features, mask_indices=None, mask_channel_indices=None):
        if input_features.dim() == 2:
            # Raw waveform: B x T -> B x 1 x T
            hidden_states = input_features.unsqueeze(1)
        else:
            # Acoustic features: B x T x (C x D) -> B x (C x D) x T
            hidden_states = input_features.transpose(1, 2).contiguous()
        for conv in self.conv_layers:
            hidden_states = conv(hidden_states)
            if self.use_glu:
                # GLU halves the channel dimension: B x 2C x T -> B x C x T
                hidden_states = nn.functional.glu(hidden_states, dim=1)
        hidden_states = hidden_states.transpose(1, 2).contiguous()  # -> B x T x (C x D)
        return hidden_states
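
if __name__ == "__main__":
    # A minimal smoke test with a mock `args` namespace. The field names
    # mirror the attributes read above; the real framework's args object is
    # not shown in this file, so this mock is an assumption.
    from types import SimpleNamespace

    args = SimpleNamespace(
        emb_size=512,
        embedding=["speech", "sinusoidalpos"],
        data_processor="s2t",        # enables GLU
        audio_feature_size=80,       # e.g. 80-dim log-mel filterbank features
        conv_channels=[1024, 1024],  # GLU halves these to 512
        conv_kernel_sizes=[5, 5],
    )
    emb = SpeechEmbedding(args, None)
    src = torch.randn(4, 100, 80)    # (B, T, audio_feature_size)
    out = emb(src, None)
    # Two stride-2 layers subsample T: 100 -> 50 -> 25; GLU leaves 512 channels.
    print(out.shape)                 # torch.Size([4, 25, 512])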