import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
def downsample(x, x_len, sample_rate, sample_style):
    '''Downsample a padded batch of shape (batch, time, feature) along the time axis.'''
    batch_size, timestep, feature_dim = x.shape
    x_len = x_len // sample_rate
if sample_style == 'drop':
# Drop the unselected timesteps
x = x[:, ::sample_rate, :].contiguous()
elif sample_style == 'concat':
# Drop the redundant frames and concat the rest according to sample rate
if timestep % sample_rate != 0:
x = x[:, :-(timestep % sample_rate), :]
        x = x.contiguous().view(
            batch_size, timestep // sample_rate, feature_dim * sample_rate)
    else:
        raise NotImplementedError(f'Unsupported sample_style: {sample_style}')
return x, x_len
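
# Illustrative usage sketch (not part of the original file): with sample_rate=2,
# 'drop' keeps every other frame, while 'concat' stacks adjacent frame pairs
# into a doubled feature dimension. All tensor shapes below are arbitrary
# assumptions for demonstration.
def _demo_downsample():
    x = torch.randn(2, 10, 8)                 # (batch, time, feature)
    x_len = torch.tensor([10, 6])             # valid length of each example
    dropped, dropped_len = downsample(x, x_len, sample_rate=2, sample_style='drop')
    assert dropped.shape == (2, 5, 8)         # time axis halved
    assert dropped_len.tolist() == [5, 3]
    stacked, stacked_len = downsample(x, x_len, sample_rate=2, sample_style='concat')
    assert stacked.shape == (2, 5, 16)        # feature dim doubled instead
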
class RNNLayer(nn.Module):
    '''RNN wrapper that includes time-downsampling.'''
def __init__(self, input_dim, module, bidirection, dim, dropout, layer_norm, sample_rate, proj):
super(RNNLayer, self).__init__()
# Setup
rnn_out_dim = 2 * dim if bidirection else dim
self.out_dim = rnn_out_dim
self.dropout = dropout
self.layer_norm = layer_norm
self.sample_rate = sample_rate
self.proj = proj
# Recurrent layer
self.layer = getattr(nn, module.upper())(
input_dim, dim, bidirectional=bidirection, num_layers=1, batch_first=True)
# Regularizations
if self.layer_norm:
self.ln = nn.LayerNorm(rnn_out_dim)
if self.dropout > 0:
self.dp = nn.Dropout(p=dropout)
# Additional projection layer
if self.proj:
self.pj = nn.Linear(rnn_out_dim, rnn_out_dim)
def forward(self, input_x, x_len):
# Forward RNN
if not self.training:
self.layer.flatten_parameters()
        # pack_padded_sequence requires the lengths on CPU; x_len is assumed to be a tensor
        input_x = pack_padded_sequence(
            input_x, x_len.cpu(), batch_first=True, enforce_sorted=False)
output, _ = self.layer(input_x)
output, x_len = pad_packed_sequence(output, batch_first=True)
# Normalizations
if self.layer_norm:
output = self.ln(output)
if self.dropout > 0:
output = self.dp(output)
# Perform Downsampling
if self.sample_rate > 1:
output, x_len = downsample(output, x_len, self.sample_rate, 'drop')
if self.proj:
output = torch.tanh(self.pj(output))
return output, x_len
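
# Illustrative usage sketch (not part of the original file): one bidirectional
# LSTM layer whose output time axis is halved by 'drop' downsampling. The
# hyperparameter values are arbitrary assumptions.
def _demo_rnn_layer():
    layer = RNNLayer(input_dim=8, module='LSTM', bidirection=True, dim=16,
                     dropout=0.1, layer_norm=True, sample_rate=2, proj=True)
    x = torch.randn(2, 10, 8)
    x_len = torch.tensor([10, 6])
    out, out_len = layer(x, x_len)
    assert out.shape == (2, 5, 32)            # 2 * dim from the bidirectional RNN
    assert out_len.tolist() == [5, 3]
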
class RNNs(nn.Module):
def __init__(self,
input_size,
output_size,
upstream_rate,
module,
bidirection,
dim,
dropout,
layer_norm,
proj,
sample_rate,
sample_style,
        total_rate=320,
):
super(RNNs, self).__init__()
latest_size = input_size
self.sample_rate = 1 if total_rate == -1 else round(total_rate / upstream_rate)
self.sample_style = sample_style
if sample_style == 'concat':
latest_size *= self.sample_rate
self.rnns = nn.ModuleList()
for i in range(len(dim)):
rnn_layer = RNNLayer(
latest_size,
module,
bidirection,
dim[i],
dropout[i],
layer_norm[i],
sample_rate[i],
proj[i],
)
self.rnns.append(rnn_layer)
latest_size = rnn_layer.out_dim
self.linear = nn.Linear(latest_size, output_size)
def forward(self, x, x_len):
r"""
Args:
x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
x_len (torch.IntTensor): Tensor of dimension (batch_size).
Returns:
Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
"""
# Perform Downsampling
if self.sample_rate > 1:
x, x_len = downsample(x, x_len, self.sample_rate, self.sample_style)
for rnn in self.rnns:
x, x_len = rnn(x, x_len)
logits = self.linear(x)
return logits, x_len
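
# Illustrative usage sketch (not part of the original file): a two-layer BLSTM
# downstream model. With total_rate=320 and an assumed upstream_rate of 160,
# the input is first downsampled by 2, and 'concat' doubles the feature
# dimension accordingly. Per-layer settings are lists with one entry per layer.
def _demo_rnns():
    model = RNNs(input_size=8, output_size=5, upstream_rate=160,
                 module='LSTM', bidirection=True, dim=[16, 16],
                 dropout=[0.1, 0.1], layer_norm=[False, False],
                 proj=[False, False], sample_rate=[1, 1],
                 sample_style='concat')
    x = torch.randn(2, 10, 8)
    x_len = torch.tensor([10, 6])
    logits, logits_len = model(x, x_len)
    assert logits.shape == (2, 5, 5)          # time halved once, 5 output classes
    assert logits_len.tolist() == [5, 3]
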
class Wav2Letter(nn.Module):
"""
The Wav2Letter model modified from torchaudio.models.Wav2Letter which preserves
total downsample rate given the different upstream downsample rate.
"""
def __init__(self, input_dim, output_dim, upstream_rate, total_rate=320, **kwargs):
super(Wav2Letter, self).__init__()
first_stride = 1 if total_rate == -1 else total_rate // upstream_rate
self.downsample_rate = first_stride
self.acoustic_model = nn.Sequential(
nn.Conv1d(in_channels=input_dim, out_channels=250, kernel_size=48, stride=first_stride, padding=23),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.Conv1d(in_channels=2000, out_channels=output_dim, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True)
)
def forward(self, x, x_len):
r"""
Args:
x (torch.Tensor): Tensor of dimension (batch_size, input_length, num_features).
x_len (torch.IntTensor): Tensor of dimension (batch_size).
Returns:
Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
"""
x = self.acoustic_model(x.transpose(1, 2).contiguous())
return x.transpose(1, 2).contiguous(), x_len // self.downsample_rate
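
# Illustrative usage sketch (not part of the original file): with the default
# total_rate=320 and an assumed upstream_rate of 160, the first convolution
# uses stride 2, so the returned lengths are halved. The padded time axis of
# the logits can differ by a frame due to convolution padding.
def _demo_wav2letter():
    model = Wav2Letter(input_dim=8, output_dim=5, upstream_rate=160)
    x = torch.randn(2, 100, 8)
    x_len = torch.tensor([100, 60])
    logits, logits_len = model(x, x_len)
    assert logits.shape[0] == 2 and logits.shape[2] == 5
    assert logits_len.tolist() == [50, 30]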