import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def downsample(x, x_len, sample_rate, sample_style):
    '''Shrink the time axis of a padded batch by a factor of sample_rate,
    either by dropping frames ('drop') or by concatenating every
    sample_rate adjacent frames along the feature axis ('concat').'''
    batch_size, timestep, feature_dim = x.shape
    x_len = x_len // sample_rate

    if sample_style == 'drop':
        # Drop the unselected timesteps
        x = x[:, ::sample_rate, :].contiguous()
    elif sample_style == 'concat':
        # Drop the redundant frames and concat the rest according to sample rate
        if timestep % sample_rate != 0:
            x = x[:, :-(timestep % sample_rate), :]
        x = x.contiguous().view(batch_size, timestep // sample_rate,
                                feature_dim * sample_rate)
    else:
        raise NotImplementedError
    return x, x_len
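
# Illustrative shape check for downsample (example values are assumptions,
# not taken from the original file): with x of shape (2, 10, 4),
# x_len = [10, 8] and sample_rate = 2, 'drop' keeps every 2nd frame and
# yields (2, 5, 4), while 'concat' stacks adjacent frame pairs along the
# feature axis and yields (2, 5, 8); x_len becomes [5, 4] either way.
#
#     x = torch.randn(2, 10, 4)
#     x_len = torch.tensor([10, 8])
#     y, y_len = downsample(x, x_len, sample_rate=2, sample_style='concat')
#     assert y.shape == (2, 5, 8) and y_len.tolist() == [5, 4]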


class RNNLayer(nn.Module):
    ''' RNN wrapper, includes time-downsampling'''

    def __init__(self, input_dim, module, bidirection, dim, dropout,
                 layer_norm, sample_rate, proj):
        super(RNNLayer, self).__init__()
        # Setup
        rnn_out_dim = 2 * dim if bidirection else dim
        self.out_dim = rnn_out_dim
        self.dropout = dropout
        self.layer_norm = layer_norm
        self.sample_rate = sample_rate
        self.proj = proj

        # Recurrent layer: nn.LSTM / nn.GRU / nn.RNN depending on `module`
        self.layer = getattr(nn, module.upper())(
            input_dim, dim, bidirectional=bidirection, num_layers=1,
            batch_first=True)

        # Regularizations
        if self.layer_norm:
            self.ln = nn.LayerNorm(rnn_out_dim)
        if self.dropout > 0:
            self.dp = nn.Dropout(p=dropout)

        # Additional projection layer
        if self.proj:
            self.pj = nn.Linear(rnn_out_dim, rnn_out_dim)

    def forward(self, input_x, x_len):
        # Forward RNN
        if not self.training:
            self.layer.flatten_parameters()
        input_x = pack_padded_sequence(
            input_x, x_len, batch_first=True, enforce_sorted=False)
        output, _ = self.layer(input_x)
        output, x_len = pad_packed_sequence(output, batch_first=True)

        # Normalizations
        if self.layer_norm:
            output = self.ln(output)
        if self.dropout > 0:
            output = self.dp(output)

        # Perform Downsampling
        if self.sample_rate > 1:
            output, x_len = downsample(output, x_len, self.sample_rate, 'drop')

        if self.proj:
            output = torch.tanh(self.pj(output))
        return output, x_len
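
# Usage sketch for RNNLayer (hypothetical hyperparameters, shown for
# illustration only; a bidirectional layer doubles `dim` at the output):
#
#     layer = RNNLayer(input_dim=80, module='LSTM', bidirection=True, dim=256,
#                      dropout=0.2, layer_norm=False, sample_rate=1, proj=True)
#     out, out_len = layer(torch.randn(4, 100, 80),
#                          torch.tensor([100, 90, 80, 70]))
#     # out.shape == (4, 100, 512)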


class RNNs(nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 upstream_rate,
                 module,
                 bidirection,
                 dim,
                 dropout,
                 layer_norm,
                 proj,
                 sample_rate,
                 sample_style,
                 total_rate=320,
                 ):
        super(RNNs, self).__init__()
        latest_size = input_size

        self.sample_rate = 1 if total_rate == -1 else round(total_rate / upstream_rate)
        self.sample_style = sample_style
        if sample_style == 'concat':
            latest_size *= self.sample_rate

        self.rnns = nn.ModuleList()
        for i in range(len(dim)):
            rnn_layer = RNNLayer(
                latest_size,
                module,
                bidirection,
                dim[i],
                dropout[i],
                layer_norm[i],
                sample_rate[i],
                proj[i],
            )
            self.rnns.append(rnn_layer)
            latest_size = rnn_layer.out_dim

        self.linear = nn.Linear(latest_size, output_size)

    def forward(self, x, x_len):
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
            x_len (torch.IntTensor): Tensor of dimension (batch_size).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, output_length, number_of_classes),
            where output_length is input_length divided by the overall downsample rate.
        """
        # Perform Downsampling
        if self.sample_rate > 1:
            x, x_len = downsample(x, x_len, self.sample_rate, self.sample_style)

        for rnn in self.rnns:
            x, x_len = rnn(x, x_len)

        logits = self.linear(x)
        return logits, x_len
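
# Construction sketch for RNNs (hypothetical values; assumes 16 kHz audio and
# an upstream producing one frame per 320 samples, i.e. upstream_rate=320, so
# the default total_rate=320 yields an extra sample_rate of 1 and no input
# downsampling before the RNN stack):
#
#     model = RNNs(input_size=768, output_size=31, upstream_rate=320,
#                  module='LSTM', bidirection=True, dim=[256, 256],
#                  dropout=[0.2, 0.2], layer_norm=[False, False],
#                  proj=[False, False], sample_rate=[1, 1],
#                  sample_style='concat')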


class Wav2Letter(nn.Module):
    """
    The Wav2Letter model, modified from torchaudio.models.Wav2Letter so that
    the total downsample rate is preserved under different upstream
    downsample rates.
    """
    def __init__(self, input_dim, output_dim, upstream_rate, total_rate=320, **kwargs):
        super(Wav2Letter, self).__init__()
        # Stride of the first conv is chosen so that
        # upstream_rate * first_stride == total_rate
        first_stride = 1 if total_rate == -1 else total_rate // upstream_rate
        self.downsample_rate = first_stride

        self.acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=250, kernel_size=48, stride=first_stride, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=output_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, x_len):
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
            x_len (torch.IntTensor): Tensor of dimension (batch_size).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, output_length, number_of_classes),
            where output_length is roughly input_length divided by self.downsample_rate.
        """
        # Conv1d expects (batch, channels, time), hence the transposes
        x = self.acoustic_model(x.transpose(1, 2).contiguous())
        return x.transpose(1, 2).contiguous(), x_len // self.downsample_rate
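

if __name__ == '__main__':
    # Minimal smoke test, added for illustration only; the hyperparameter
    # values below are assumptions, not values from the original repository.
    dummy = torch.randn(2, 50, 40)
    dummy_len = torch.tensor([50, 35])

    # upstream_rate=160 with the default total_rate=320 implies an extra
    # 2x downsampling inside both models.
    rnn_model = RNNs(input_size=40, output_size=31, upstream_rate=160,
                     module='LSTM', bidirection=True, dim=[128],
                     dropout=[0.1], layer_norm=[False], proj=[False],
                     sample_rate=[1], sample_style='drop')
    logits, logits_len = rnn_model(dummy, dummy_len)
    print(logits.shape, logits_len)  # torch.Size([2, 25, 31]) tensor([25, 17])

    w2l_model = Wav2Letter(input_dim=40, output_dim=31, upstream_rate=160)
    logits, logits_len = w2l_model(dummy, dummy_len)
    print(logits.shape, logits_len)  # time axis halved by the strided conv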