|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class PreNet(nn.Module):
    """Two-layer fully-connected bottleneck with ReLU and dropout.

    Each layer applies linear -> ReLU -> dropout; the dropout is gated on
    ``self.training`` (active only in train mode).
    """

    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        # Dropout probability shared by both layers.
        self.p = dropout

    def forward(self, x):
        # Apply the same linear -> ReLU -> dropout pipeline twice.
        for layer in (self.fc1, self.fc2):
            x = F.dropout(F.relu(layer(x)), self.p, training=self.training)
        return x
|
|
|
|
|
class HighwayNetwork(nn.Module):
    """Highway layer: ``y = g * relu(W1 x) + (1 - g) * x`` with ``g = sigmoid(W2 x)``."""

    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        # Zero the transform-path bias (W1); the gate bias (W2) keeps its default init.
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        transform = F.relu(self.W1(x))
        gate = torch.sigmoid(self.W2(x))
        # Gated mix of the transformed path and the identity path.
        return gate * transform + (1. - gate) * x
|
|
|
|
|
class BatchNormConv(nn.Module):
    """Bias-free 1-d convolution, optional ReLU, then batch normalisation.

    Note the order: conv -> ReLU -> BatchNorm1d — the activation is applied
    *before* normalisation, as written in the original forward pass.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel: kernel size; padding is ``kernel // 2``, so even kernels
            produce one extra output frame (callers trim back to seq_len).
        relu: whether to apply ReLU between the conv and the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        # bias=False: BatchNorm1d supplies its own affine shift.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1,
                              padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        # Truthiness test instead of the `is True` literal comparison.
        if self.relu:
            x = F.relu(x)
        return self.bnorm(x)
|
|
|
|
|
class ConvNorm(torch.nn.Module):
    """``Conv1d`` with Xavier-uniform weight init and automatic 'same' padding.

    When ``padding`` is None it is derived from the kernel size and dilation;
    this only works for odd kernels (asserted).
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            # Length-preserving padding is only defined for odd kernels.
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        # Scale the Xavier init by the gain of the intended nonlinearity.
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=gain)

    def forward(self, signal):
        return self.conv(signal)
|
|
|
|
|
class CBHG(nn.Module):
    """CBHG module (Conv Bank + Highway + bidirectional GRU) from Tacotron.

    A bank of K 1-d convolutions (kernel sizes 1..K) is max-pooled, projected
    with two further convolutions, added back to the input as a residual,
    passed through a stack of highway layers, and finally through a
    bidirectional GRU (output width ``2 * channels``).

    Args:
        K: number of conv-bank kernels (sizes 1..K).
        in_channels: channels of the input (batch, channels, time) tensor.
        channels: width of the bank convs, highway layers and GRU.
        proj_channels: two-element list of projection conv widths.
        num_highways: number of stacked highway layers.
    """

    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # RNN modules whose weights should be flattened for cuDNN.
        self._to_flatten = []

        # Convolution bank: one BatchNormConv per kernel size 1..K.
        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList(
            BatchNormConv(in_channels, channels, k) for k in self.bank_kernels)

        # Stride-1 pooling keeps the time resolution (padding adds one frame,
        # trimmed in forward()).
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # Projection convs; the second omits ReLU so the residual addition
        # below is not clipped to non-negative values.
        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels,
                                           proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3,
                                           relu=False)

        # Bridge projection width to highway width when they differ.
        self.highway_mismatch = proj_channels[-1] != channels
        if self.highway_mismatch:
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)

        self.highways = nn.ModuleList(
            HighwayNetwork(channels) for _ in range(num_highways))

        self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        self._flatten_parameters()

    def forward(self, x):
        # NOTE(review): re-flattening every call presumably guards against
        # non-contiguous replicated weights (e.g. DataParallel) — confirm.
        self._flatten_parameters()

        residual = x
        seq_len = x.size(-1)

        # Run every bank conv and trim to the input length (even kernels
        # produce one extra frame due to padding=kernel//2).
        conv_bank = torch.cat(
            [conv(x)[:, :, :seq_len] for conv in self.conv1d_bank], dim=1)

        # Max pool, trimming the extra padded frame.
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Projection convs followed by the residual connection.
        x = self.conv_project1(x)
        x = self.conv_project2(x)
        x = x + residual

        # Highway layers operate on (batch, time, channels).
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for h in self.highways:
            x = h(x)

        # Bidirectional GRU; output width is 2 * channels.
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Call `flatten_parameters` on all RNNs held by this module so their
        weights are contiguous in memory (improves efficiency and silences
        the cuDNN warning)."""
        # Plain loop: this is executed for its side effect only.
        for rnn in self._to_flatten:
            rnn.flatten_parameters()
|
|
|
|
|
class TacotronEncoder(nn.Module):
    """Tacotron text encoder: embedding -> PreNet -> CBHG -> linear projection."""

    def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)
        # The bidirectional GRU doubles the width; project it back down.
        self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)

    def forward(self, x):
        embedded = self.pre_net(self.embedding(x))
        # CBHG expects (batch, channels, time).
        features = self.cbhg(embedded.transpose(1, 2))
        return self.proj_out(features)
|
|
|
|
|
class RNNEncoder(nn.Module):
    """Encoder: embedding -> residual conv stack -> bidirectional LSTM.

    Token id 0 is the padding index; per-sequence lengths are inferred from
    the non-zero token positions.
    """

    def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
        super(RNNEncoder, self).__init__()
        self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)

        # n_convolutions identical ConvNorm + BatchNorm1d blocks.
        self.convolutions = nn.ModuleList(
            nn.Sequential(
                ConvNorm(embedding_dim,
                         embedding_dim,
                         kernel_size=kernel_size, stride=1,
                         padding=int((kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(embedding_dim))
            for _ in range(n_convolutions))

        # Half-width per direction so the concatenated output matches embedding_dim.
        self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        # Count non-padding tokens per sequence for packing.
        input_lengths = (x > 0).sum(-1).cpu().numpy()

        x = self.embedding(x).transpose(1, 2)
        for conv in self.convolutions:
            # Residual conv block; dropout follows self.training.
            x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
        x = x.transpose(1, 2)

        # Pack so the LSTM skips padded timesteps.
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths,
                                              batch_first=True,
                                              enforce_sorted=False)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs
|
|
|
|
|
class DecoderRNN(torch.nn.Module):
    """Conv stack -> LayerNorm -> bidirectional LSTM -> conv projection.

    Padded frames (all-zero feature vectors) are detected from the input and
    zeroed in the output.
    """

    def __init__(self, hidden_size, decoder_rnn_dim, dropout):
        super(DecoderRNN, self).__init__()
        # Two wide (kernel 9, padding 4) convolutions with a ReLU between.
        self.in_conv1d = nn.Sequential(
            torch.nn.Conv1d(in_channels=hidden_size,
                            out_channels=hidden_size,
                            kernel_size=9, padding=4),
            torch.nn.ReLU(),
            torch.nn.Conv1d(in_channels=hidden_size,
                            out_channels=hidden_size,
                            kernel_size=9, padding=4),
        )
        self.ln = nn.LayerNorm(hidden_size)
        # decoder_rnn_dim == 0 means "use the default of 2 * hidden_size".
        if decoder_rnn_dim == 0:
            decoder_rnn_dim = hidden_size * 2
        self.rnn = torch.nn.LSTM(input_size=hidden_size,
                                 hidden_size=decoder_rnn_dim,
                                 num_layers=1,
                                 batch_first=True,
                                 bidirectional=True,
                                 dropout=dropout)
        self.rnn.flatten_parameters()
        # Project the bidirectional output back to hidden_size.
        self.conv1d = torch.nn.Conv1d(in_channels=decoder_rnn_dim * 2,
                                      out_channels=hidden_size,
                                      kernel_size=3,
                                      padding=1)

    def forward(self, x):
        # A frame is "real" iff its features are not all zero: (B, T, 1) mask.
        input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
        input_lengths = input_masks.sum([-1, -2]).cpu().numpy()

        x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
        x = self.ln(x)
        # Pack so the LSTM skips padded timesteps.
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths,
                                              batch_first=True,
                                              enforce_sorted=False)
        self.rnn.flatten_parameters()
        x, _ = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        # Zero padded frames both before and after the final projection.
        x = x * input_masks
        pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
        return pre_mel * input_masks
|
|