import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from ..hparams import hparams as hp
from .global_style_token import GlobalStyleToken
from ..gst_hyperparameters import hparams as gst_hp
from ...log import logger
class HighwayNetwork(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.0)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1.0 - g) * x
        return y
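
# The highway layer above computes y = g * relu(W1 x) + (1 - g) * x with gate g = sigmoid(W2 x),
# so it learns how much of each position to transform versus pass through unchanged; input and
# output share the same trailing dimension `size`.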
class Encoder(nn.Module):
    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
        super().__init__()
        prenet_dims = (encoder_dims, encoder_dims)
        cbhg_channels = encoder_dims
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(
            embed_dims,
            fc1_dims=prenet_dims[0],
            fc2_dims=prenet_dims[1],
            dropout=dropout,
        )
        self.cbhg = CBHG(
            K=K,
            in_channels=cbhg_channels,
            channels=cbhg_channels,
            proj_channels=[cbhg_channels, cbhg_channels],
            num_highways=num_highways,
        )

    def forward(self, x, speaker_embedding=None):
        x = self.embedding(x)
        x = self.pre_net(x)
        x.transpose_(1, 2)
        x = self.cbhg(x)
        if speaker_embedding is not None:
            x = self.add_speaker_embedding(x, speaker_embedding)
        return x

    def add_speaker_embedding(self, x, speaker_embedding):
        # SV2TTS
        # The input x is the encoder output, a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
        # When training, speaker_embedding is a 2D tensor with size (batch_size, speaker_embedding_size)
        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size,))
        # This concatenates the speaker embedding to the encoder output for each character

        # Save the dimensions as human-readable names
        batch_size = x.size()[0]
        num_chars = x.size()[1]

        if speaker_embedding.dim() == 1:
            idx = 0
        else:
            idx = 1

        # Start by making a copy of each speaker embedding to match the input text length
        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
        speaker_embedding_size = speaker_embedding.size()[idx]
        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)

        # Reshape it and transpose
        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
        e = e.transpose(1, 2)

        # Concatenate the tiled speaker embedding with the encoder output
        x = torch.cat((x, e), 2)
        return x
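
# Shape walk-through for Encoder.add_speaker_embedding (illustrative numbers, not fixed by the
# model): with x of size (2, 5, 128) and a training-time speaker_embedding of size (2, 256),
# repeat_interleave gives (2, 256 * 5), the reshape gives (2, 256, 5), the transpose gives
# (2, 5, 256), and the final concatenation returns a tensor of size (2, 5, 128 + 256).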
class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False
        )
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu else x
        return self.bnorm(x)
class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(
            len(self.bank_kernels) * channels, proj_channels[0], 3
        )
        self.conv_project2 = BatchNormConv(
            proj_channels[0], proj_channels[1], 3, relu=False
        )

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self.rnn.flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution Bank
        for conv in self.conv1d_bank:
            c = conv(x)  # Convolution
            conv_bank.append(c[:, :, :seq_len])

        # Stack along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # Dump the last padding to fit the residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual connection
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch is True:
            x = self.pre_highway(x)
        for h in self.highways:
            x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        for m in self._to_flatten:
            m.flatten_parameters()
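
# CBHG data flow (illustrative shapes): for an input of size (B, channels, T), the K bank
# convolutions each preserve T and are concatenated into (B, K * channels, T); max-pooling and
# the two projections reduce this again and, in the encoder configuration where
# proj_channels == [channels, channels], the result matches the input so the residual can be
# added; the tensor is then transposed to (B, T, channels) for the highway layers and the
# bidirectional GRU, whose two (channels // 2)-wide directions produce a (B, T, channels) output.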
class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        # training=True keeps dropout active even in eval mode, as is usual for Tacotron-style prenets
        x = F.dropout(x, self.p, training=True)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        return x
class Attention(nn.Module):
    def __init__(self, attn_dims):
        super().__init__()
        self.W = nn.Linear(attn_dims, attn_dims, bias=False)
        self.v = nn.Linear(attn_dims, 1, bias=False)

    def forward(self, encoder_seq_proj, query, t):
        # Transform the query vector
        query_proj = self.W(query).unsqueeze(1)

        # Compute the scores
        u = self.v(torch.tanh(encoder_seq_proj + query_proj))
        scores = F.softmax(u, dim=1)
        return scores.transpose(1, 2)
class LSA(nn.Module):
    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(
            1,
            filters,
            padding=(kernel_size - 1) // 2,
            kernel_size=kernel_size,
            bias=True,
        )
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(
            attn_dim, attn_dim, bias=True
        )  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        device = encoder_seq_proj.device  # use same device as parameters
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, t, chars):
        if t == 0:
            self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars
        u = u * (chars != 0).float()

        # Smooth Attention
        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)
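
# Location-sensitive attention shapes (illustrative, following the code above): for a batch B and
# a character sequence of length T, encoder_seq_proj is (B, T, attn_dim) and query is (B, attn_dim);
# the cumulative attention (B, T) is convolved as a (B, 1, T) "location" signal into (B, filters, T)
# and projected to (B, T, attn_dim); the combined scores are softmaxed over T and returned as
# (B, 1, T), ready for the `scores @ encoder_seq` context product in the decoder.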
class Decoder(nn.Module):
    # Class variable because its value doesn't change between instances,
    # yet it ought to be scoped to the class because it is a property of a Decoder
    max_r = 20

    def __init__(
        self,
        n_mels,
        encoder_dims,
        decoder_dims,
        lstm_dims,
        dropout,
        speaker_embedding_size,
    ):
        super().__init__()
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels
        prenet_dims = (decoder_dims * 2, decoder_dims * 2)
        self.prenet = PreNet(
            n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], dropout=dropout
        )
        self.attn_net = LSA(decoder_dims)
        if hp.use_gst:
            speaker_embedding_size += gst_hp.E
        self.attn_rnn = nn.GRUCell(
            encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims
        )
        self.rnn_input = nn.Linear(
            encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims
        )
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

    def zoneout(self, prev, current, device, p=0.1):
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    def forward(
        self,
        encoder_seq,
        encoder_seq_proj,
        prenet_in,
        hidden_states,
        cell_states,
        context_vec,
        t,
        chars,
    ):
        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)
        device = encoder_seq.device

        # Unpack the hidden and cell states
        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # PreNet for the Attention RNN
        prenet_out = self.prenet(prenet_in)

        # Compute the Attention RNN hidden state
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)

        # Compute the attention scores
        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)

        # Dot product to create the context vector
        context_vec = scores @ encoder_seq
        context_vec = context_vec.squeeze(1)

        # Concat Attention RNN output w. Context Vector & project
        x = torch.cat([context_vec, attn_hidden], dim=1)
        x = self.rnn_input(x)

        # Compute first Residual RNN
        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
        if self.training:
            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
        else:
            rnn1_hidden = rnn1_hidden_next
        x = x + rnn1_hidden

        # Compute second Residual RNN
        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
        if self.training:
            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
        else:
            rnn2_hidden = rnn2_hidden_next
        x = x + rnn2_hidden

        # Project Mels
        mels = self.mel_proj(x)
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, : self.r]
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop token prediction
        s = torch.cat((x, context_vec), dim=1)
        s = self.stop_proj(s)
        stop_tokens = torch.sigmoid(s)

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
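
# Reduction factor note (based on the code above): mel_proj always predicts max_r frames per
# decoder step and the view/slice keeps only the first `r` of them, so with an assumed n_mels of
# 80 and r of 2 each call returns a (batch, 80, 2) chunk of the spectrogram; `r` lives in a
# registered buffer so that it is saved with checkpoints and can be changed during training.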
class Tacotron(nn.Module):
    def __init__(
        self,
        embed_dims,
        num_chars,
        encoder_dims,
        decoder_dims,
        n_mels,
        fft_bins,
        postnet_dims,
        encoder_K,
        lstm_dims,
        postnet_K,
        num_highways,
        dropout,
        stop_threshold,
        speaker_embedding_size,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(
            embed_dims, num_chars, encoder_dims, encoder_K, num_highways, dropout
        )
        project_dims = encoder_dims + speaker_embedding_size
        if hp.use_gst:
            project_dims += gst_hp.E
        self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
        if hp.use_gst:
            self.gst = GlobalStyleToken(speaker_embedding_size)
        self.decoder = Decoder(
            n_mels,
            encoder_dims,
            decoder_dims,
            lstm_dims,
            dropout,
            speaker_embedding_size,
        )
        self.postnet = CBHG(
            postnet_K, n_mels, postnet_dims, [postnet_dims, fft_bins], num_highways
        )
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer(
            "stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)
        )

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    @staticmethod
    def _concat_speaker_embedding(outputs, speaker_embeddings):
        speaker_embeddings_ = speaker_embeddings.expand(
            outputs.size(0), outputs.size(1), -1
        )
        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
        return outputs
    def forward(self, texts, mels, speaker_embedding):
        device = texts.device  # use same device as parameters

        self.step += 1
        batch_size, _, steps = mels.size()

        # Initialise all hidden states and pack into tuple
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Initialise all lstm cell states and pack into tuple
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        size = self.encoder_dims + self.speaker_embedding_size
        if hp.use_gst:
            size += gst_hp.E
        context_vec = torch.zeros(batch_size, size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(texts, speaker_embedding)

        # GST: compute the style embedding after the encoder and append it to its output
        if hp.use_gst and self.gst is not None:
            # During training the speaker embedding serves as both the style input and the reference
            style_embed = self.gst(speaker_embedding, speaker_embedding)
            # style_embed = style_embed.expand_as(encoder_seq)
            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
            (
                mel_frames,
                scores,
                hidden_states,
                cell_states,
                context_vec,
                stop_tokens,
            ) = self.decoder(
                encoder_seq,
                encoder_seq_proj,
                prenet_in,
                hidden_states,
                cell_states,
                context_vec,
                t,
                texts,
            )
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        # attn_scores = attn_scores.cpu().data.numpy()
        stop_outputs = torch.cat(stop_outputs, 1)

        return mel_outputs, linear, attn_scores, stop_outputs
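
    # Training-time output shapes, as implied by the loop above: mel_outputs is
    # (batch, n_mels, steps); linear is (batch, fft_bins, steps) after the postnet CBHG and
    # post_proj; attn_scores stacks one (batch, 1, num_chars) attention map per decoder step;
    # stop_outputs is (batch, steps) because each step contributes its stop token r times.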
    def generate(
        self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5
    ):
        self.eval()
        device = x.device  # use same device as parameters

        batch_size, _ = x.size()

        # Need to initialise all hidden states and pack into tuple for tidyness
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Need to initialise all lstm cell states and pack into tuple for tidyness
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # Need a <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        size = self.encoder_dims + self.speaker_embedding_size
        if hp.use_gst:
            size += gst_hp.E
        context_vec = torch.zeros(batch_size, size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)

        # GST: compute the style embedding after the encoder and append it to its output
        if hp.use_gst and self.gst is not None:
            if 0 <= style_idx < 10:
                # Use one of the first ten learned style tokens directly
                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
                if device.type == "cuda":
                    query = query.cuda()
                gst_embed = torch.tanh(self.gst.stl.embed)
                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
                style_embed = self.gst.stl.attention(query, key)
            else:
                # Otherwise derive the style from the speaker embedding itself
                speaker_embedding_style = torch.zeros(
                    speaker_embedding.size()[0], 1, self.speaker_embedding_size
                ).to(device)
                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
            # style_embed = style_embed.expand_as(encoder_seq)
            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            (
                mel_frames,
                scores,
                hidden_states,
                cell_states,
                context_vec,
                stop_tokens,
            ) = self.decoder(
                encoder_seq,
                encoder_seq_proj,
                prenet_in,
                hidden_states,
                cell_states,
                context_vec,
                t,
                x,
            )
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

            # Stop the loop when all stop tokens in batch exceed threshold
            if (stop_tokens * 10 > min_stop_token).all() and t > 10:
                break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)

        stop_outputs = torch.cat(stop_outputs, 1)

        self.train()

        return mel_outputs, linear, attn_scores
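
    # Inference differs from training in two ways visible above: the decoder is fed its own last
    # predicted frame instead of ground-truth mels, and the loop exits early once every stop token
    # in the batch clears the (scaled) min_stop_token threshold. With GST enabled, a style_idx in
    # [0, 10) picks a single learned style token, while any other value derives the style from the
    # provided speaker embedding.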
    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def finetune_partial(self, whitelist_layers):
        self.zero_grad()
        for name, child in self.named_children():
            if name in whitelist_layers:
                logger.debug("Trainable Layer: %s" % name)
                logger.debug(
                    "Trainable Parameters: %.3f"
                    % sum([np.prod(p.size()) for p in child.parameters()])
                )
            else:
                # Freeze every layer that is not whitelisted
                for param in child.parameters():
                    param.requires_grad = False
    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def load(self, path, device, optimizer=None):
        # Use device of model params as location for loaded state
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"], strict=False)

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        if optimizer is not None:
            torch.save(
                {
                    "model_state": self.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                },
                str(path),
            )
        else:
            torch.save(
                {
                    "model_state": self.state_dict(),
                },
                str(path),
            )

    def num_params(self):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        logger.debug("Trainable Parameters: %.3fM" % parameters)
        return parameters
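

# A minimal smoke-test sketch (not part of the original module). The hyperparameter values below
# are illustrative assumptions, not the project's configuration; fft_bins is set equal to n_mels
# so that the postnet CBHG's residual connection lines up. Because of the relative imports this
# only runs as part of the package (e.g. `python -m <package>.tacotron`), and it assumes
# hp.use_gst is disabled so no GST module or style embedding is involved.
if __name__ == "__main__":
    model = Tacotron(
        embed_dims=256,
        num_chars=70,
        encoder_dims=128,
        decoder_dims=256,
        n_mels=80,
        fft_bins=80,
        postnet_dims=128,
        encoder_K=16,
        lstm_dims=512,
        postnet_K=8,
        num_highways=4,
        dropout=0.5,
        stop_threshold=-3.4,
        speaker_embedding_size=256,
    )
    model.r = 2  # exercise the reduction-factor setter

    texts = torch.randint(1, 70, (2, 16), dtype=torch.long)  # (batch, chars)
    mels = torch.rand(2, 80, 40)  # (batch, n_mels, steps), steps divisible by r
    speaker_embedding = torch.rand(2, 256)  # (batch, speaker_embedding_size)

    mel_out, linear_out, attn, stop = model(texts, mels, speaker_embedding)
    print(mel_out.shape, linear_out.shape, attn.shape, stop.shape)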