from math import sqrt

import torch
import torch.distributions as distr
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from .layers import ConvNorm, LinearNorm, GlobalAvgPool
from .utils import to_gpu, get_mask_from_lengths


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
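        # Shape walkthrough (derived from the layers above):
        #   attention_weights_cat: (B, 2, T) -- previous and cumulative attention weights
        #   location_conv  -> (B, attention_n_filters, T)
        #   transpose(1,2) -> (B, T, attention_n_filters)
        #   location_dense -> (B, T, attention_dim)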
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: attention rnn output (batch, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)
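        # Softmax over encoder time steps; the context vector is then the
        # attention-weighted sum of the encoder outputs.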
        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
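        # Dropout is applied with training=True on purpose, so it stays active
        # at inference as well (Tacotron 2 keeps prenet dropout on to add variation).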
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
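        # tanh + dropout on every convolution except the last one, which is a
        # linear (no activation) projection back to n_mel_channels.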
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)
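        # Pack the padded batch so the bidirectional LSTM skips padded frames;
        # pack_padded_sequence expects the lengths on the CPU.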
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class AudioEncoder(nn.Module):
    def __init__(self, hparams):
        super(AudioEncoder, self).__init__()

        assert hparams.lat_dim > 0

        convolutions = []
        inp_dim = hparams.n_mel_channels
        for _ in range(hparams.lat_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(inp_dim, hparams.lat_n_filters,
                         kernel_size=hparams.lat_kernel_size, stride=1,
                         padding=int((hparams.lat_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.lat_n_filters))
            inp_dim = hparams.lat_n_filters
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.lat_n_filters,
                            int(hparams.lat_n_filters / 2),
                            hparams.lat_n_blstms, batch_first=True,
                            bidirectional=True)
        self.pool = GlobalAvgPool()

        self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.lat_dim = hparams.lat_dim

    def forward(self, x, lengths):
        """
        Args:
            x (torch.Tensor): (B, F, T)
        """
        for conv in self.convolutions:
            x = F.dropout(torch.tanh(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        max_len = x.size(1)
        assert max_len == torch.max(lengths).item()
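        # pack_padded_sequence requires sequences sorted by decreasing length:
        # sort, pack, run the BLSTM, then restore the original batch order.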
        lengths, perm_idx = lengths.sort(0, descending=True)
        x = x[perm_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        _, unperm_idx = perm_idx.sort(0)
        outputs = outputs[unperm_idx]
        lengths = lengths[unperm_idx]

        outputs = self.pool(outputs, lengths)
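        # Reparameterization trick: assuming logvar_proj predicts a log-variance
        # (as its name suggests), sample z = mu + exp(0.5 * logvar) * eps via
        # rsample() so gradients flow through mu and logvar.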
        mu = self.mu_proj(outputs)
        logvar = self.logvar_proj(outputs)
        z = distr.Normal(mu, torch.exp(0.5 * logvar)).rsample()
        return z, mu, logvar


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.obs_dim = hparams.obs_dim
        self.lat_dim = hparams.lat_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)
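        # The decoder is also conditioned on the observed/latent attribute
        # embeddings, so their dims are added to the context size below.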
        encoder_tot_dim = (hparams.encoder_embedding_dim +
                           hparams.lat_dim + hparams.obs_dim)
        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + encoder_tot_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, obs_and_lat, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.obs_and_lat = obs_and_lat
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step), -1)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        """
        alignments = torch.stack(alignments).transpose(0, 1)

        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()

        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()

        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)

        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        if self.obs_and_lat is not None:
            decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        if self.obs_and_lat is not None:
            decoder_hidden_attention_context = torch.cat(
                (decoder_hidden_attention_context, self.obs_and_lat), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
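        # Teacher forcing: prepend the all-zero <go> frame, run the ground-truth
        # mel frames through the prenet, and feed one frame per decoder step.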
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory, obs_and_lat, ret_has_eos=False):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, obs_and_lat, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        has_eos = False
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]
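            # Stop once the gate (stop-token) probability passes the threshold,
            # or bail out after max_decoder_steps to avoid an infinite loop.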
            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                has_eos = True
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        if ret_has_eos:
            return mel_outputs, gate_outputs, alignments, has_eos
        else:
            return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
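        # Symbol embedding initialized uniformly in [-val, val], a Xavier-style
        # bound with std = sqrt(2 / (n_symbols + embedding_dim)) and val = sqrt(3) * std.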
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-val, val)

        self.obs_embedding = None
        if hparams.obs_dim > 0:
            self.obs_embedding = nn.Embedding(
                hparams.obs_n_class, hparams.obs_dim)
            std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim))
            val = sqrt(3.0) * std
            self.obs_embedding.weight.data.uniform_(-val, val)

        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

        self.lat_encoder = None
        if hparams.lat_dim > 0:
            self.lat_encoder = AudioEncoder(hparams)

    def parse_batch(self, batch):
        (text_padded, input_lengths, obs_labels,
         mel_padded, gate_padded, output_lengths) = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        obs_labels = to_gpu(obs_labels).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, obs_labels,
             mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)
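            # Zero out padded frames in both mel predictions and push the gate
            # energies to a large positive value on padding so the stop token
            # saturates (sigmoid ~ 1) there.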
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)

        return outputs

    def forward(self, inputs):
        (text_inputs, text_lengths, obs_labels,
         mels, max_len, output_lengths) = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        lat, lat_mu, lat_logvar = None, None, None
        if self.lat_encoder is not None:
            (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths)
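        # Concatenate whichever of the observed / latent embeddings are present;
        # the decoder receives None when neither conditioning signal is used.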
        obs_and_lat = [x for x in [obs, lat] if x is not None]
        if bool(obs_and_lat):
            obs_and_lat = torch.cat(obs_and_lat, dim=-1)
        else:
            obs_and_lat = None

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
             lat_mu, lat_logvar],
            output_lengths)

    def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
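        # Default to observed-attribute class 0 and an all-zero latent vector
        # when no conditioning is provided at inference time.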
        if obs_labels is None:
            obs_labels = torch.LongTensor(len(inputs))
            obs_labels = obs_labels.to(inputs.device).zero_()

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        if self.lat_encoder is not None:
            if lat is None:
                lat = torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim)
                lat = lat.to(inputs.device).zero_().type(encoder_outputs.type())

        obs_and_lat = [x for x in [obs, lat] if x is not None]
        if bool(obs_and_lat):
            obs_and_lat = torch.cat(obs_and_lat, dim=-1)
        else:
            obs_and_lat = None

        mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference(
            encoder_outputs, obs_and_lat, ret_has_eos=True)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        if ret_has_eos:
            return outputs + [has_eos]
        else:
            return outputs