|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tacotron2 decoder related modules.""" |
|
|
|
import six |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA |
|
|
|
|
|
def decoder_init(m):
    """Initialize decoder parameters.

    Written to be used as the callback of ``torch.nn.Module.apply``: it is
    called once per submodule and only touches 1-D convolutions.

    Args:
        m (torch.nn.Module): Module to initialize. If it is a
            ``torch.nn.Conv1d`` its weight is re-initialized in place with
            Xavier-uniform scaled by the tanh gain; any other module is left
            unchanged.

    """
    if isinstance(m, torch.nn.Conv1d):
        torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
|
|
|
|
|
class ZoneOutCell(torch.nn.Module):
    """ZoneOut Cell module.

    This is a module of zoneout described in
    `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
    This code is modified from `eladhoffer/seq2seq.pytorch`_.

    Examples:
        >>> lstm = torch.nn.LSTMCell(16, 32)
        >>> lstm = ZoneOutCell(lstm, 0.5)

    .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
        https://arxiv.org/abs/1606.01305

    .. _`eladhoffer/seq2seq.pytorch`:
        https://github.com/eladhoffer/seq2seq.pytorch

    """

    def __init__(self, cell, zoneout_rate=0.1):
        """Initialize zone out cell module.

        Args:
            cell (torch.nn.Module): Pytorch recurrent cell module
                e.g. `torch.nn.LSTMCell`. It must expose `hidden_size`.
            zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.

        Raises:
            ValueError: If `zoneout_rate` is outside of [0.0, 1.0] (NaN included).

        """
        super(ZoneOutCell, self).__init__()
        # A chained comparison is used on purpose: it is False for NaN, whereas
        # ``rate > 1.0 or rate < 0.0`` silently lets NaN through.
        if not 0.0 <= zoneout_rate <= 1.0:
            raise ValueError(
                "zoneout probability must be in the range from 0.0 to 1.0."
            )
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout_rate = zoneout_rate

    def forward(self, inputs, hidden):
        """Calculate forward propagation.

        Args:
            inputs (Tensor): Batch of input tensor (B, input_size).
            hidden (tuple):
                - Tensor: Batch of initial hidden states (B, hidden_size).
                - Tensor: Batch of initial cell states (B, hidden_size).

        Returns:
            tuple:
                - Tensor: Batch of next hidden states (B, hidden_size).
                - Tensor: Batch of next cell states (B, hidden_size).

        """
        next_hidden = self.cell(inputs, hidden)
        return self._zoneout(hidden, next_hidden, self.zoneout_rate)

    def _zoneout(self, h, next_h, prob):
        """Mix previous state `h` and new state `next_h` with rate `prob`.

        Tuples of states are handled element-wise; a scalar `prob` is shared
        by every element of the tuple.
        """
        # apply recursively
        if isinstance(h, tuple):
            if not isinstance(prob, tuple):
                prob = (prob,) * len(h)
            return tuple(
                self._zoneout(h[i], next_h[i], prob[i]) for i in range(len(h))
            )

        if self.training:
            # mask == 1 keeps the previous state, mask == 0 takes the new one
            mask = h.new_empty(h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h

        # in evaluation the random mask is replaced by its expectation
        return prob * h + (1 - prob) * next_h
|
|
|
|
|
class Prenet(torch.nn.Module):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Prenet performs nonlinear conversion
    of inputs before input to auto-regressive lstm,
    which helps to learn diagonal attentions.

    Note:
        This module always applies dropout even in evaluation.
        See the detail in `Natural TTS Synthesis by
        Conditioning WaveNet on Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Args:
            idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units
                (this is also the output dimension when `n_layers` > 0).
            dropout_rate (float, optional): Dropout rate applied after each layer.

        """
        super(Prenet, self).__init__()
        self.dropout_rate = dropout_rate
        self.prenet = torch.nn.ModuleList()
        for layer in range(n_layers):
            # only the first layer sees the raw input dimension
            n_inputs = idim if layer == 0 else n_units
            self.prenet.append(
                torch.nn.Sequential(torch.nn.Linear(n_inputs, n_units), torch.nn.ReLU())
            )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., idim).

        Returns:
            Tensor: Batch of output tensors (B, ..., n_units).
                The input is returned unchanged if the module has no layers.

        """
        for layer in self.prenet:
            # F.dropout defaults to training=True, hence dropout stays active
            # regardless of self.training (see the class note)
            x = F.dropout(layer(x), self.dropout_rate)
        return x
|
|
|
|
|
class Postnet(torch.nn.Module):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network,
    which described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted
    Mel-filterbank of the decoder,
    which helps to compensate the detail structure of spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(
        self,
        idim,
        odim,
        n_layers=5,
        n_chans=512,
        n_filts=5,
        dropout_rate=0.5,
        use_batch_norm=True,
    ):
        """Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Kept for interface
                compatibility only; it is not used, the first convolution
                takes `odim` input channels.
            odim (int): Dimension of the outputs (and of the expected inputs).
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The number of filter size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to use batch normalization.

        """
        super(Postnet, self).__init__()
        self.postnet = torch.nn.ModuleList()

        # hidden layers: conv (+ batch norm) + tanh + dropout
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet.append(
                self._build_layer(
                    ichans, n_chans, n_filts, dropout_rate, use_batch_norm, True
                )
            )

        # output layer: same without tanh, projecting back to odim
        ichans = n_chans if n_layers != 1 else odim
        self.postnet.append(
            self._build_layer(ichans, odim, n_filts, dropout_rate, use_batch_norm, False)
        )

    @staticmethod
    def _build_layer(ichans, ochans, n_filts, dropout_rate, use_batch_norm, use_tanh):
        """Build one conv block: Conv1d [-> BatchNorm1d] [-> Tanh] -> Dropout."""
        modules = [
            torch.nn.Conv1d(
                ichans,
                ochans,
                n_filts,
                stride=1,
                padding=(n_filts - 1) // 2,
                bias=False,
            )
        ]
        if use_batch_norm:
            modules.append(torch.nn.BatchNorm1d(ochans))
        if use_tanh:
            modules.append(torch.nn.Tanh())
        modules.append(torch.nn.Dropout(dropout_rate))
        return torch.nn.Sequential(*modules)

    def forward(self, xs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors
                (B, odim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).

        """
        for layer in self.postnet:
            xs = layer(xs)
        return xs
|
|
|
|
|
class Decoder(torch.nn.Module):
    """Decoder module of Spectrogram prediction network.

    This is a module of decoder of Spectrogram prediction network in Tacotron2,
    which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The decoder generates the sequence of
    features from the sequence of the hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(
        self,
        idim,
        odim,
        att,
        dlayers=2,
        dunits=1024,
        prenet_layers=2,
        prenet_units=256,
        postnet_layers=5,
        postnet_chans=512,
        postnet_filts=5,
        output_activation_fn=None,
        cumulate_att_w=True,
        use_batch_norm=True,
        use_concate=True,
        dropout_rate=0.5,
        zoneout_rate=0.1,
        reduction_factor=1,
    ):
        """Initialize Tacotron2 decoder module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            att (torch.nn.Module): Instance of attention class.
            dlayers (int, optional): The number of decoder lstm layers.
            dunits (int, optional): The number of decoder lstm units.
            prenet_layers (int, optional): The number of prenet layers.
            prenet_units (int, optional): The number of prenet units.
            postnet_layers (int, optional): The number of postnet layers.
            postnet_filts (int, optional): The number of postnet filter size.
            postnet_chans (int, optional): The number of postnet filter channels.
            output_activation_fn (torch.nn.Module, optional):
                Activation function for outputs.
            cumulate_att_w (bool, optional):
                Whether to cumulate previous attention weight.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_concate (bool, optional): Whether to concatenate encoder embedding
                with decoder lstm outputs.
            dropout_rate (float, optional): Dropout rate.
            zoneout_rate (float, optional): Zoneout rate.
            reduction_factor (int, optional): Reduction factor.

        """
        super(Decoder, self).__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # AttForwardTA additionally receives the previous output frame
        self.use_att_extra_inputs = isinstance(self.att, AttForwardTA)

        # define lstm network
        # without a prenet the previous frame (odim) is fed to the lstm directly
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = torch.nn.ModuleList()
        for layer in range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = torch.nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm.append(lstm)

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate,
            )
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate,
            )
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
        self.prob_out = torch.nn.Linear(iunits, reduction_factor)

        # initialize
        self.apply(decoder_init)

    def _zero_state(self, hs):
        """Return a zero tensor (B, dunits) with the dtype and device of `hs`."""
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def _init_states(self, hs):
        """Create zero lstm states and the zero initial "previous output".

        Returns:
            list: Cell states, one (B, dunits) tensor per lstm layer.
            list: Hidden states, one (B, dunits) tensor per lstm layer.
            Tensor: Initial previous output frame (B, odim).

        """
        n_lstm = len(self.lstm)
        c_list = [self._zero_state(hs) for _ in range(n_lstm)]
        z_list = [self._zero_state(hs) for _ in range(n_lstm)]
        prev_out = hs.new_zeros(hs.size(0), self.odim)
        return c_list, z_list, prev_out

    def _lstm_step(self, att_c, prev_out, z_list, c_list):
        """Run prenet and the lstm stack for one step.

        `z_list` and `c_list` are updated in place with the new states.
        """
        prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
        xs = torch.cat([att_c, prenet_out], dim=1)
        z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
        for i in range(1, len(self.lstm)):
            z_list[i], c_list[i] = self.lstm[i](
                z_list[i - 1], (z_list[i], c_list[i])
            )

    def _next_prev_att_w(self, prev_att_w, att_w):
        """Return the attention weight to feed to the next attention call."""
        if self.cumulate_att_w and prev_att_w is not None:
            # out-of-place add: prev_att_w may alias a tensor already stored
            # in the returned list of attention weights
            return prev_att_w + att_w
        return att_w

    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor):
                Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of output tensors after postnet (B, Lmax, odim).
            Tensor: Batch of output tensors before postnet (B, Lmax, odim).
            Tensor: Batch of logits of stop prediction (B, Lmax).
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # keep only every r-th target frame (the last frame of each group)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        outs, logits, att_ws = [], [], []
        for y in ys.transpose(0, 1):
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
            else:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
            self._lstm_step(att_c, prev_out, z_list, c_list)
            zcs = (
                torch.cat([z_list[-1], att_c], dim=1)
                if self.use_concate
                else z_list[-1]
            )
            outs.append(self.feat_out(zcs).view(hs.size(0), self.odim, -1))
            logits.append(self.prob_out(zcs))
            att_ws.append(att_w)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_prev_att_w(prev_att_w, att_w)

        logits = torch.cat(logits, dim=1)
        before_outs = torch.cat(outs, dim=2)
        att_ws = torch.stack(att_ws, dim=1)

        if self.reduction_factor > 1:
            before_outs = before_outs.view(
                before_outs.size(0), self.odim, -1
            )

        if self.postnet is not None:
            # postnet predicts a residual on top of the decoder output
            after_outs = before_outs + self.postnet(before_outs)
        else:
            after_outs = before_outs
        before_outs = before_outs.transpose(2, 1)
        after_outs = after_outs.transpose(2, 1)

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws

    def inference(
        self,
        h,
        threshold=0.5,
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=None,
        forward_window=None,
    ):
        """Generate the sequence of features given the sequences of characters.

        Args:
            h (Tensor): Input sequence of encoder hidden states (T, C).
            threshold (float, optional): Threshold to stop generation.
            minlenratio (float, optional): Minimum length ratio.
                If set to 1.0 and the length of input is 10,
                the minimum length of outputs will be 10 * 1 = 10.
            maxlenratio (float, optional): Maximum length ratio.
                If set to 10 and the length of input is 10,
                the maximum length of outputs will be 10 * 10 = 100.
            use_att_constraint (bool):
                Whether to apply attention constraint introduced in `Deep Voice 3`_.
            backward_window (int): Backward window size in attention constraint.
            forward_window (int): Forward window size in attention constraint.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        Note:
            This computation is performed in auto-regressive manner.

        .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654

        """
        # setup: a single unbatched sequence is decoded as a batch of one
        assert len(h.size()) == 2
        hs = h.unsqueeze(0)
        ilens = [h.size(0)]
        maxlen = int(h.size(0) * maxlenratio)
        minlen = int(h.size(0) * minlenratio)

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # setup for attention constraint
        last_attended_idx = 0 if use_att_constraint else None

        # loop for an output sequence
        idx = 0
        outs, att_ws, probs = [], [], []
        while True:
            # each step emits reduction_factor frames
            idx += self.reduction_factor

            # decoder calculation
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(
                    hs,
                    ilens,
                    z_list[0],
                    prev_att_w,
                    prev_out,
                    last_attended_idx=last_attended_idx,
                    backward_window=backward_window,
                    forward_window=forward_window,
                )
            else:
                att_c, att_w = self.att(
                    hs,
                    ilens,
                    z_list[0],
                    prev_att_w,
                    last_attended_idx=last_attended_idx,
                    backward_window=backward_window,
                    forward_window=forward_window,
                )

            att_ws.append(att_w)
            self._lstm_step(att_c, prev_out, z_list, c_list)
            zcs = (
                torch.cat([z_list[-1], att_c], dim=1)
                if self.use_concate
                else z_list[-1]
            )
            outs.append(self.feat_out(zcs).view(1, self.odim, -1))
            probs.append(torch.sigmoid(self.prob_out(zcs))[0])
            # the last generated frame is fed back as the next input
            if self.output_activation_fn is not None:
                prev_out = self.output_activation_fn(outs[-1][:, :, -1])
            else:
                prev_out = outs[-1][:, :, -1]
            prev_att_w = self._next_prev_att_w(prev_att_w, att_w)
            if use_att_constraint:
                last_attended_idx = int(att_w.argmax())

            # check whether to finish generation
            if bool((probs[-1] >= threshold).any()) or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=2)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)
                outs = outs.transpose(2, 1).squeeze(0)
                probs = torch.cat(probs, dim=0)
                att_ws = torch.cat(att_ws, dim=0)
                break

        if self.output_activation_fn is not None:
            outs = self.output_activation_fn(outs)

        return outs, probs, att_ws

    def calculate_all_attentions(self, hs, hlens, ys):
        """Calculate all of the attention weights.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor):
                Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # keep only every r-th target frame (the last frame of each group)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        att_ws = []
        for y in ys.transpose(0, 1):
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
            else:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
            att_ws.append(att_w)
            self._lstm_step(att_c, prev_out, z_list, c_list)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_prev_att_w(prev_att_w, att_w)

        att_ws = torch.stack(att_ws, dim=1)

        return att_ws
|
|