#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 decoder related modules."""
import six
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
def decoder_init(m):
"""Initialize decoder parameters."""
if isinstance(m, torch.nn.Conv1d):
torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
class ZoneOutCell(torch.nn.Module):
"""ZoneOut Cell module.
This is a module of zoneout described in
`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
This code is modified from `eladhoffer/seq2seq.pytorch`_.
Examples:
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
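        >>> # a minimal sketch of one step (shapes are illustrative)
        >>> inputs = torch.randn(4, 16)
        >>> hidden = (torch.zeros(4, 32), torch.zeros(4, 32))
        >>> hidden = lstm(inputs, hidden)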
.. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
https://arxiv.org/abs/1606.01305
.. _`eladhoffer/seq2seq.pytorch`:
https://github.com/eladhoffer/seq2seq.pytorch
"""
def __init__(self, cell, zoneout_rate=0.1):
"""Initialize zone out cell module.
Args:
            cell (torch.nn.Module): PyTorch recurrent cell module,
                e.g., `torch.nn.LSTMCell`.
zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.
"""
super(ZoneOutCell, self).__init__()
self.cell = cell
self.hidden_size = cell.hidden_size
self.zoneout_rate = zoneout_rate
if zoneout_rate > 1.0 or zoneout_rate < 0.0:
raise ValueError(
"zoneout probability must be in the range from 0.0 to 1.0."
)
def forward(self, inputs, hidden):
"""Calculate forward propagation.
Args:
inputs (Tensor): Batch of input tensor (B, input_size).
hidden (tuple):
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
Returns:
tuple:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
"""
next_hidden = self.cell(inputs, hidden)
next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate)
return next_hidden
def _zoneout(self, h, next_h, prob):
# apply recursively
if isinstance(h, tuple):
num_h = len(h)
if not isinstance(prob, tuple):
prob = tuple([prob] * num_h)
return tuple(
[self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h)]
)
if self.training:
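            # during training, each unit keeps its previous value with
            # probability ``prob`` (zoneout); otherwise it takes the new value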
mask = h.new(*h.size()).bernoulli_(prob)
return mask * h + (1 - mask) * next_h
else:
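            # at evaluation time, use the expected value of the zoneout mask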
return prob * h + (1 - prob) * next_h
class Prenet(torch.nn.Module):
"""Prenet module for decoder of Spectrogram prediction network.
    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which is described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Prenet performs nonlinear conversion
    of inputs before inputting to the auto-regressive LSTM,
    which helps to learn diagonal attentions.
    Note:
        This module always applies dropout even in evaluation.
        See the details in `Natural TTS Synthesis by
        Conditioning WaveNet on Mel Spectrogram Predictions`_.
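    Examples:
        A minimal sketch of standalone usage (dimensions are illustrative):
        >>> prenet = Prenet(idim=80, n_layers=2, n_units=256)
        >>> x = torch.randn(4, 80)
        >>> y = prenet(x)  # (4, 256)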
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
"""Initialize prenet module.
Args:
idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units.
            dropout_rate (float, optional): Dropout rate.
"""
super(Prenet, self).__init__()
self.dropout_rate = dropout_rate
self.prenet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers):
n_inputs = idim if layer == 0 else n_units
self.prenet += [
torch.nn.Sequential(torch.nn.Linear(n_inputs, n_units), torch.nn.ReLU())
]
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., idim).
Returns:
            Tensor: Batch of output tensors (B, ..., n_units).
"""
for i in six.moves.range(len(self.prenet)):
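            # note: F.dropout defaults to training=True, so dropout is applied
            # even in evaluation mode (see the Note in the class docstring)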
x = F.dropout(self.prenet[i](x), self.dropout_rate)
return x
class Postnet(torch.nn.Module):
"""Postnet module for Spectrogram prediction network.
    This is a module of Postnet in Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted
    Mel-filterbank of the decoder,
    which helps to compensate for the detailed structure of the spectrogram.
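    Examples:
        A minimal sketch of standalone usage (dimensions are illustrative;
        in the decoder, the output is added to the input as a residual):
        >>> postnet = Postnet(idim=80, odim=80)
        >>> xs = torch.randn(4, 80, 100)  # (B, odim, Tmax)
        >>> residual = postnet(xs)  # (4, 80, 100)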
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(
self,
idim,
odim,
n_layers=5,
n_chans=512,
n_filts=5,
dropout_rate=0.5,
use_batch_norm=True,
):
"""Initialize postnet module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): Filter size (kernel size) of the convolutions.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            dropout_rate (float, optional): Dropout rate.
"""
super(Postnet, self).__init__()
self.postnet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers - 1):
ichans = odim if layer == 0 else n_chans
ochans = odim if layer == n_layers - 1 else n_chans
if use_batch_norm:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
ochans,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.BatchNorm1d(ochans),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
ochans,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate),
)
]
ichans = n_chans if n_layers != 1 else odim
if use_batch_norm:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
odim,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.BatchNorm1d(odim),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
odim,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.Dropout(dropout_rate),
)
]
def forward(self, xs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
Returns:
            Tensor: Batch of padded output tensors (B, odim, Tmax).
"""
for i in six.moves.range(len(self.postnet)):
xs = self.postnet[i](xs)
return xs
class Decoder(torch.nn.Module):
"""Decoder module of Spectrogram prediction network.
This is a module of decoder of Spectrogram prediction network in Tacotron2,
    which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
The decoder generates the sequence of
features from the sequence of the hidden states.
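    Examples:
        A minimal sketch of teacher-forced decoding
        (the attention module and all dimensions are illustrative):
        >>> from espnet.nets.pytorch_backend.rnn.attentions import AttLoc
        >>> att = AttLoc(512, 1024, 128, 32, 15)
        >>> decoder = Decoder(idim=512, odim=80, att=att)
        >>> hs = torch.randn(4, 100, 512)
        >>> hlens = torch.LongTensor([100, 90, 80, 70])
        >>> ys = torch.randn(4, 200, 80)
        >>> after_outs, before_outs, logits, att_ws = decoder(hs, hlens, ys)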
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(
self,
idim,
odim,
att,
dlayers=2,
dunits=1024,
prenet_layers=2,
prenet_units=256,
postnet_layers=5,
postnet_chans=512,
postnet_filts=5,
output_activation_fn=None,
cumulate_att_w=True,
use_batch_norm=True,
use_concate=True,
dropout_rate=0.5,
zoneout_rate=0.1,
reduction_factor=1,
):
"""Initialize Tacotron2 decoder module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
att (torch.nn.Module): Instance of attention class.
dlayers (int, optional): The number of decoder lstm layers.
dunits (int, optional): The number of decoder lstm units.
prenet_layers (int, optional): The number of prenet layers.
prenet_units (int, optional): The number of prenet units.
postnet_layers (int, optional): The number of postnet layers.
            postnet_filts (int, optional): Filter size of the postnet convolutions.
postnet_chans (int, optional): The number of postnet filter channels.
output_activation_fn (torch.nn.Module, optional):
Activation function for outputs.
cumulate_att_w (bool, optional):
Whether to cumulate previous attention weight.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_concate (bool, optional): Whether to concatenate encoder embedding
with decoder lstm outputs.
dropout_rate (float, optional): Dropout rate.
zoneout_rate (float, optional): Zoneout rate.
reduction_factor (int, optional): Reduction factor.
"""
super(Decoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.odim = odim
self.att = att
self.output_activation_fn = output_activation_fn
self.cumulate_att_w = cumulate_att_w
self.use_concate = use_concate
self.reduction_factor = reduction_factor
# check attention type
if isinstance(self.att, AttForwardTA):
self.use_att_extra_inputs = True
else:
self.use_att_extra_inputs = False
# define lstm network
prenet_units = prenet_units if prenet_layers != 0 else odim
self.lstm = torch.nn.ModuleList()
for layer in six.moves.range(dlayers):
iunits = idim + prenet_units if layer == 0 else dunits
lstm = torch.nn.LSTMCell(iunits, dunits)
if zoneout_rate > 0.0:
lstm = ZoneOutCell(lstm, zoneout_rate)
self.lstm += [lstm]
# define prenet
if prenet_layers > 0:
self.prenet = Prenet(
idim=odim,
n_layers=prenet_layers,
n_units=prenet_units,
dropout_rate=dropout_rate,
)
else:
self.prenet = None
# define postnet
if postnet_layers > 0:
self.postnet = Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=dropout_rate,
)
else:
self.postnet = None
# define projection layers
iunits = idim + dunits if use_concate else dunits
self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
self.prob_out = torch.nn.Linear(iunits, reduction_factor)
# initialize
self.apply(decoder_init)
def _zero_state(self, hs):
init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
return init_hs
def forward(self, hs, hlens, ys):
"""Calculate forward propagation.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor):
Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
Tensor: Batch of output tensors after postnet (B, Lmax, odim).
Tensor: Batch of output tensors before postnet (B, Lmax, odim).
Tensor: Batch of logits of stop prediction (B, Lmax).
Tensor: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
outs, logits, att_ws = [], [], []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in six.moves.range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
zcs = (
torch.cat([z_list[-1], att_c], dim=1)
if self.use_concate
else z_list[-1]
)
outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)]
logits += [self.prob_out(zcs)]
att_ws += [att_w]
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
logits = torch.cat(logits, dim=1) # (B, Lmax)
before_outs = torch.cat(outs, dim=2) # (B, odim, Lmax)
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
if self.reduction_factor > 1:
before_outs = before_outs.view(
before_outs.size(0), self.odim, -1
) # (B, odim, Lmax)
if self.postnet is not None:
after_outs = before_outs + self.postnet(before_outs) # (B, odim, Lmax)
else:
after_outs = before_outs
before_outs = before_outs.transpose(2, 1) # (B, Lmax, odim)
after_outs = after_outs.transpose(2, 1) # (B, Lmax, odim)
# apply activation function for scaling
if self.output_activation_fn is not None:
before_outs = self.output_activation_fn(before_outs)
after_outs = self.output_activation_fn(after_outs)
return after_outs, before_outs, logits, att_ws
def inference(
self,
h,
threshold=0.5,
minlenratio=0.0,
maxlenratio=10.0,
use_att_constraint=False,
backward_window=None,
forward_window=None,
):
"""Generate the sequence of features given the sequences of characters.
Args:
h (Tensor): Input sequence of encoder hidden states (T, C).
threshold (float, optional): Threshold to stop generation.
minlenratio (float, optional): Minimum length ratio.
If set to 1.0 and the length of input is 10,
the minimum length of outputs will be 10 * 1 = 10.
            maxlenratio (float, optional): Maximum length ratio.
                If set to 10 and the length of input is 10,
                the maximum length of outputs will be 10 * 10 = 100.
use_att_constraint (bool):
Whether to apply attention constraint introduced in `Deep Voice 3`_.
backward_window (int): Backward window size in attention constraint.
forward_window (int): Forward window size in attention constraint.
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
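        Examples:
            A minimal sketch (the construction of ``decoder`` is illustrative;
            see the class docstring):
            >>> decoder.eval()
            >>> h = torch.randn(100, 512)  # encoder hidden states (T, C)
            >>> outs, probs, att_ws = decoder.inference(h)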
Note:
This computation is performed in auto-regressive manner.
.. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
"""
# setup
assert len(h.size()) == 2
hs = h.unsqueeze(0)
ilens = [h.size(0)]
maxlen = int(h.size(0) * maxlenratio)
minlen = int(h.size(0) * minlenratio)
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(1, self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# setup for attention constraint
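        # the index of the most attended input frame is tracked so that
        # attention can be restricted to a window around it (Deep Voice 3 style)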
if use_att_constraint:
last_attended_idx = 0
else:
last_attended_idx = None
# loop for an output sequence
idx = 0
outs, att_ws, probs = [], [], []
while True:
            # update index
idx += self.reduction_factor
# decoder calculation
if self.use_att_extra_inputs:
att_c, att_w = self.att(
hs,
ilens,
z_list[0],
prev_att_w,
prev_out,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window,
)
else:
att_c, att_w = self.att(
hs,
ilens,
z_list[0],
prev_att_w,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window,
)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in six.moves.range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
zcs = (
torch.cat([z_list[-1], att_c], dim=1)
if self.use_concate
else z_list[-1]
)
outs += [self.feat_out(zcs).view(1, self.odim, -1)] # [(1, odim, r), ...]
probs += [torch.sigmoid(self.prob_out(zcs))[0]] # [(r), ...]
if self.output_activation_fn is not None:
prev_out = self.output_activation_fn(outs[-1][:, :, -1]) # (1, odim)
else:
prev_out = outs[-1][:, :, -1] # (1, odim)
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
if use_att_constraint:
last_attended_idx = int(att_w.argmax())
# check whether to finish generation
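            # finish if any of the last r stop probabilities exceeds the
            # threshold or the maximum length is reached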
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=2) # (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
att_ws = torch.cat(att_ws, dim=0)
break
if self.output_activation_fn is not None:
outs = self.output_activation_fn(outs)
return outs, probs, att_ws
def calculate_all_attentions(self, hs, hlens, ys):
"""Calculate all of the attention weights.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor):
Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
att_ws = []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in six.moves.range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
return att_ws