# NOTE(review): removed scraped Hugging Face Spaces page header that was fused
# into this file ("Spaces:", "Runtime error" status, file size 4,994 bytes,
# commit 7900c16, and a run of editor gutter line numbers) — it was not Python
# source and made the file unparseable.
import torch
import torch.nn as nn
import math
class SpeechEmbedding(nn.Module):
    """Embed a speech signal via convolutional subsampling.

    Wraps a :class:`Conv1dModule` and, when the embedding configuration
    requests sinusoidal positional encodings, rescales the output by
    ``sqrt(emb_size)`` (the standard Transformer embedding scaling).
    """
    def __init__(self, args, _):
        super(SpeechEmbedding, self).__init__()
        self.conv = Conv1dModule(args)
        self.emb_size = args.emb_size
        # Scale only when sinusoidal positional encodings are configured;
        # derived directly instead of init-then-conditionally-overwrite.
        self.sinusoidalpos = "sinusoidalpos" in args.embedding
    def forward(self, src, _):
        """Embed inputs.

        Args:
            src (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``

        Returns:
            FloatTensor: convolutional embedding, multiplied by
            ``sqrt(emb_size)`` when sinusoidal positions are in use.
        """
        speech_emb = self.conv(src)
        if self.sinusoidalpos:
            speech_emb = speech_emb * math.sqrt(self.emb_size)
        return speech_emb
class Transpose_module(nn.Module):
    """Swap the last two dimensions of a tensor.

    Used inside ``nn.Sequential`` stacks to move the channel axis around
    a ``LayerNorm`` that normalizes over the final dimension.
    """
    def __init__(self):
        super().__init__()
    def forward(self, x):
        """Return ``x`` with its final two axes exchanged."""
        return torch.transpose(x, -2, -1)
class Conv1dModule(nn.Module):
    """
    Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
    via gated linear units (https://arxiv.org/abs/1911.08460)
    """
    def __init__(self, args):
        super(Conv1dModule, self).__init__()
        # Configuration mirrored from ``args``; several knobs are fixed here
        # rather than exposed (norm_mode, grad mult, bias, dropout, padding).
        self.embedding_dim = args.emb_size
        self.norm_mode = None
        self.feature_grad_mult = 1.0
        self.conv_bias = True
        self.dropout_input = 0.0
        self.use_glu = args.data_processor == "s2t"
        self.padding = True
        self.conv_channels = args.conv_channels
        self.audio_feature_size = args.audio_feature_size
        self.kernel_sizes = args.conv_kernel_sizes
        # Every layer halves the temporal resolution.
        self.strides = [2] * len(self.kernel_sizes)
        self.conv_layers = nn.ModuleList()
        assert len(self.strides) == len(self.kernel_sizes), "strides and kernel_sizes are not matched"
        assert len(self.strides) == len(self.conv_channels), "strides and conv_channels are not matched"
        # assumes all conv_channels are equal when audio_feature_size != 1,
        # since the input width of every later layer stays fixed — TODO confirm
        prev_channels = self.conv_channels[0] // 2
        for idx, (kernel, stride, channels) in enumerate(zip(self.kernel_sizes, self.strides, self.conv_channels)):
            if self.audio_feature_size == 1:
                prev_channels = channels
            # "group" normalization would only ever apply to the first layer.
            if self.norm_mode == "group" and idx != 0:
                self.norm_mode = None
            pad = kernel // 2 if self.padding else 0
            in_ch = self.audio_feature_size if idx == 0 else prev_channels
            self.conv_layers.append(
                self._build_block(in_ch, channels, kernel, stride, pad, self.norm_mode, self.conv_bias)
            )
    def _build_block(self, in_channels, out_channels, kernel_size, stride, padding, norm_mode, conv_bias):
        """Build one conv block: a Kaiming-initialized Conv1d followed by the
        configured normalization/activation combination."""
        conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=conv_bias)
        nn.init.kaiming_normal_(conv.weight)
        if norm_mode == "layer":
            # LayerNorm normalizes the last axis, so transpose around it.
            return nn.Sequential(
                conv,
                Transpose_module(),
                nn.LayerNorm(out_channels, eps=1e-5, elementwise_affine=True),
                Transpose_module(),
                nn.GELU(),
            )
        if norm_mode == "group":
            return nn.Sequential(
                conv,
                nn.GroupNorm(out_channels, out_channels, eps=1e-5, affine=True),
                nn.GELU(),
            )
        if self.use_glu:
            # GLU is applied functionally in ``forward``; block is conv only.
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.GELU())
    def forward(self, input_features, mask_indices=None, mask_channel_indices=None):
        """Subsample ``input_features`` along time.

        Accepts a raw waveform ``(B, T)`` or acoustic features
        ``(B, T, C*D)``; returns a ``(B, T', C')`` tensor.
        """
        if input_features.dim() == 2:
            # raw wav: B x T -> B x 1 x T
            hidden = input_features.unsqueeze(1)
        else:
            # acoustic features: B x T x (C x D) -> B x (C x D) x T
            hidden = input_features.transpose(1, 2).contiguous()
        for block in self.conv_layers:
            hidden = block(hidden)
            if self.use_glu:
                # GLU halves the channel dimension.
                hidden = nn.functional.glu(hidden, dim=1)
        # back to B x T x (C x D)
        return hidden.transpose(1, 2).contiguous()
# NOTE(review): stray table delimiter from the scraped page removed.