import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def downsample(x, x_len, sample_rate, sample_style):
    '''Downsample a padded batch of sequences along the time axis.'''
    batch_size, timestep, feature_dim = x.shape
    x_len = x_len // sample_rate

    if sample_style == 'drop':
        # Keep every `sample_rate`-th frame and discard the frames in between.
        x = x[:, ::sample_rate, :].contiguous()
    elif sample_style == 'concat':
        # Stack every `sample_rate` consecutive frames along the feature axis;
        # trailing frames that do not fill a full group are truncated.
        if timestep % sample_rate != 0:
            x = x[:, :-(timestep % sample_rate), :]
        x = x.contiguous().view(
            batch_size, timestep // sample_rate, feature_dim * sample_rate)
    else:
        raise NotImplementedError(f'Unsupported sample_style: {sample_style}')

    return x, x_len
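# A minimal usage sketch of `downsample` (the tensors below are toy values
# assumed for illustration, not part of the original module):
#
#   x = torch.randn(2, 10, 4)                     # (batch, time, feature)
#   x_len = torch.tensor([10, 8])
#   y, y_len = downsample(x, x_len, 2, 'drop')    # y: (2, 5, 4), y_len: [5, 4]
#   y, y_len = downsample(x, x_len, 2, 'concat')  # y: (2, 5, 8), y_len: [5, 4]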


class RNNLayer(nn.Module):
    ''' A single RNN layer wrapper with optional time-downsampling. '''

    def __init__(self, input_dim, module, bidirection, dim, dropout, layer_norm, sample_rate, proj):
        super(RNNLayer, self).__init__()

        rnn_out_dim = 2 * dim if bidirection else dim
        self.out_dim = rnn_out_dim
        self.dropout = dropout
        self.layer_norm = layer_norm
        self.sample_rate = sample_rate
        self.proj = proj

        # `module` names the recurrent cell by its nn attribute, e.g. 'LSTM' or 'GRU'.
        self.layer = getattr(nn, module.upper())(
            input_dim, dim, bidirectional=bidirection, num_layers=1, batch_first=True)

        if self.layer_norm:
            self.ln = nn.LayerNorm(rnn_out_dim)
        if self.dropout > 0:
            self.dp = nn.Dropout(p=dropout)
        if self.proj:
            self.pj = nn.Linear(rnn_out_dim, rnn_out_dim)

    def forward(self, input_x, x_len):
        if not self.training:
            self.layer.flatten_parameters()

        # Pack the padded batch so the RNN skips padding frames. Note that
        # recent PyTorch versions expect the lengths tensor to live on the CPU.
        input_x = pack_padded_sequence(input_x, x_len, batch_first=True, enforce_sorted=False)
        output, _ = self.layer(input_x)
        output, x_len = pad_packed_sequence(output, batch_first=True)

        if self.layer_norm:
            output = self.ln(output)
        if self.dropout > 0:
            output = self.dp(output)

        if self.sample_rate > 1:
            output, x_len = downsample(output, x_len, self.sample_rate, 'drop')

        if self.proj:
            output = torch.tanh(self.pj(output))

        return output, x_len
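# A standalone usage sketch of `RNNLayer` (hyperparameter values are assumed
# for illustration): a bidirectional GRU that halves the time axis.
#
#   layer = RNNLayer(input_dim=80, module='GRU', bidirection=True, dim=128,
#                    dropout=0.1, layer_norm=True, sample_rate=2, proj=True)
#   out, out_len = layer(torch.randn(3, 50, 80), torch.tensor([50, 42, 36]))
#   # out: (3, 25, 256), out_len: [25, 21, 18]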


class RNNs(nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 upstream_rate,
                 module,
                 bidirection,
                 dim,
                 dropout,
                 layer_norm,
                 proj,
                 sample_rate,
                 sample_style,
                 total_rate=320,
                 ):
        super(RNNs, self).__init__()
        latest_size = input_size

        # Downsample up front so the overall stride matches `total_rate`,
        # e.g. a 160-stride upstream with total_rate=320 gives sample_rate=2.
        self.sample_rate = 1 if total_rate == -1 else round(total_rate / upstream_rate)
        self.sample_style = sample_style
        if sample_style == 'concat':
            latest_size *= self.sample_rate

        self.rnns = nn.ModuleList()
        for i in range(len(dim)):
            rnn_layer = RNNLayer(
                latest_size,
                module,
                bidirection,
                dim[i],
                dropout[i],
                layer_norm[i],
                sample_rate[i],
                proj[i],
            )
            self.rnns.append(rnn_layer)
            latest_size = rnn_layer.out_dim

        self.linear = nn.Linear(latest_size, output_size)

    def forward(self, x, x_len):
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
            x_len (torch.IntTensor): Tensor of dimension (batch_size).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, output_length, number_of_classes),
            where output_length is input_length reduced by the total downsample rate.
        """
        if self.sample_rate > 1:
            x, x_len = downsample(x, x_len, self.sample_rate, self.sample_style)

        for rnn in self.rnns:
            x, x_len = rnn(x, x_len)

        logits = self.linear(x)
        return logits, x_len
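# A minimal construction sketch for `RNNs` (the values are illustrative
# assumptions, not prescribed defaults): two bidirectional LSTM layers over
# 160-stride upstream features, concat-downsampled 2x up front so the total
# stride stays at 320.
#
#   model = RNNs(input_size=768, output_size=31, upstream_rate=160,
#                module='LSTM', bidirection=True, dim=[512, 512],
#                dropout=[0.2, 0.2], layer_norm=[False, False],
#                proj=[False, False], sample_rate=[1, 1],
#                sample_style='concat')
#   logits, logits_len = model(torch.randn(4, 200, 768),
#                              torch.tensor([200, 180, 160, 140]))
#   # logits: (4, 100, 31), logits_len: [100, 90, 80, 70]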


class Wav2Letter(nn.Module):
    """
    A Wav2Letter variant modified from torchaudio.models.Wav2Letter that
    preserves the total downsample rate regardless of the upstream model's
    own downsample rate.
    """

    def __init__(self, input_dim, output_dim, upstream_rate, total_rate=320, **kwargs):
        super(Wav2Letter, self).__init__()
        # Absorb the remaining stride into the first convolution so the
        # overall stride stays at `total_rate`.
        first_stride = 1 if total_rate == -1 else total_rate // upstream_rate
        self.downsample_rate = first_stride

        self.acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=250, kernel_size=48, stride=first_stride, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=output_dim, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, x_len):
        r"""
        Args:
            x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
            x_len (torch.IntTensor): Tensor of dimension (batch_size).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, output_length, number_of_classes),
            where output_length is roughly input_length divided by the first stride.
        """
        # Conv1d expects (batch, channels, time), so transpose in and out.
        x = self.acoustic_model(x.transpose(1, 2).contiguous())
        return x.transpose(1, 2).contiguous(), x_len // self.downsample_rate
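# A minimal usage sketch of `Wav2Letter` (shapes are illustrative
# assumptions): 768-dim features from a 160-stride upstream, so the first
# convolution uses stride 2 to keep the total stride at 320.
#
#   model = Wav2Letter(input_dim=768, output_dim=31, upstream_rate=160)
#   logits, logits_len = model(torch.randn(4, 200, 768),
#                              torch.tensor([200, 180, 160, 140]))
#   # logits: (4, 101, 31), logits_len: [100, 90, 80, 70]
#   # (the even-sized kernel-32 convolution adds one frame, hence 101 vs. 100)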