"""CBHG related modules.""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from torch.nn.utils.rnn import pack_padded_sequence |
|
from torch.nn.utils.rnn import pad_packed_sequence |
|
|
|
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
|
|
|
|
|
class CBHGLoss(torch.nn.Module):
    """Loss function module for CBHG."""

    def __init__(self, use_masking=True):
        """Initialize CBHG loss module.

        Args:
            use_masking (bool): Whether to mask padded part in loss calculation.

        """
        super(CBHGLoss, self).__init__()
        self.use_masking = use_masking
|
    def forward(self, cbhg_outs, spcs, olens):
        """Calculate forward propagation.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of ground-truth spectrograms (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.

        """
        # mask padded parts so that they do not contribute to the loss
        if self.use_masking:
            mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
            spcs = spcs.masked_select(mask)
            cbhg_outs = cbhg_outs.masked_select(mask)

        # calculate L1 and MSE losses on the (possibly masked) outputs
        cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
        cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)

        return cbhg_l1_loss, cbhg_mse_loss
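
# Usage sketch for CBHGLoss (illustrative only; the batch size, lengths, and
# spectrogram dimension below are assumed, not taken from a real recipe):
#
#     criterion = CBHGLoss(use_masking=True)
#     cbhg_outs = torch.randn(2, 100, 513)  # (B, Lmax, spc_dim)
#     spcs = torch.randn(2, 100, 513)       # ground-truth spectrograms
#     olens = torch.tensor([100, 80])       # valid lengths per utterance
#     l1_loss, mse_loss = criterion(cbhg_outs, spcs, olens)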
|
|
class CBHG(torch.nn.Module):
    """CBHG module to convert log Mel-filterbanks to linear spectrogram.

    This is a module of CBHG introduced in
    `Tacotron: Towards End-to-End Speech Synthesis`_,
    which converts a sequence of log Mel-filterbanks into a linear spectrogram.

    .. _`Tacotron: Towards End-to-End Speech Synthesis`:
        https://arxiv.org/abs/1703.10135

    """
|
    def __init__(
        self,
        idim,
        odim,
        conv_bank_layers=8,
        conv_bank_chans=128,
        conv_proj_filts=3,
        conv_proj_chans=256,
        highway_layers=4,
        highway_units=128,
        gru_units=256,
    ):
        """Initialize CBHG module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            conv_bank_layers (int, optional): The number of convolution bank layers.
            conv_bank_chans (int, optional): The number of channels in convolution bank.
            conv_proj_filts (int, optional):
                Kernel size of convolutional projection layer.
            conv_proj_chans (int, optional):
                The number of channels in convolutional projection layer.
            highway_layers (int, optional): The number of highway network layers.
            highway_units (int, optional): The number of highway network units.
            gru_units (int, optional): The number of GRU units (for both directions).

        """
        super(CBHG, self).__init__()
        self.idim = idim
        self.odim = odim
        self.conv_bank_layers = conv_bank_layers
        self.conv_bank_chans = conv_bank_chans
        self.conv_proj_filts = conv_proj_filts
        self.conv_proj_chans = conv_proj_chans
        self.highway_layers = highway_layers
        self.highway_units = highway_units
        self.gru_units = gru_units

        # define 1d convolution bank; the k-th layer uses kernel size k, and
        # padding is chosen so that the sequence length is preserved
        self.conv_bank = torch.nn.ModuleList()
        for k in range(1, self.conv_bank_layers + 1):
            if k % 2 != 0:  # odd kernel: symmetric padding
                padding = (k - 1) // 2
            else:  # even kernel: asymmetric padding to keep the length
                padding = ((k - 1) // 2, (k - 1) // 2 + 1)
            self.conv_bank += [
                torch.nn.Sequential(
                    torch.nn.ConstantPad1d(padding, 0.0),
                    torch.nn.Conv1d(
                        idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
                    ),
                    torch.nn.BatchNorm1d(self.conv_bank_chans),
                    torch.nn.ReLU(),
                )
            ]
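
        # Worked example of the padding arithmetic above (illustrative): for
        # k = 3, padding = 1 on both sides gives length T + 2 - 3 + 1 = T;
        # for k = 2, padding = (0, 1) gives length T + 1 - 2 + 1 = T, so every
        # bank layer maps (B, idim, T) to (B, conv_bank_chans, T).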
|
        # define max pooling (padding + stride-1 pooling preserves the length)
        self.max_pool = torch.nn.Sequential(
            torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
        )

        # define projection layers mapping back to idim for the residual add
        self.projections = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.conv_bank_chans * self.conv_bank_layers,
                self.conv_proj_chans,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.conv_proj_chans),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                self.conv_proj_chans,
                self.idim,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.idim),
        )

        # define highway network; the first linear layer adjusts the dimension
        self.highways = torch.nn.ModuleList()
        self.highways += [torch.nn.Linear(idim, self.highway_units)]
        for _ in range(self.highway_layers):
            self.highways += [HighwayNet(self.highway_units)]

        # define bidirectional GRU; gru_units is split across the two directions
        self.gru = torch.nn.GRU(
            self.highway_units,
            gru_units // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # define final projection to the output dimension
        self.output = torch.nn.Linear(gru_units, odim, bias=True)
|
    def forward(self, xs, ilens):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input sequence (B,).

        Returns:
            Tensor: Batch of the padded sequences of outputs (B, Tmax, odim).
            LongTensor: Batch of lengths of each output sequence (B,).

        """
        xs = xs.transpose(1, 2)  # (B, idim, Tmax)
        convs = []
        for k in range(self.conv_bank_layers):
            convs += [self.conv_bank[k](xs)]
        convs = torch.cat(convs, dim=1)  # (B, #CH * #BANK, Tmax)
        convs = self.max_pool(convs)
        convs = self.projections(convs).transpose(1, 2)  # (B, Tmax, idim)
        xs = xs.transpose(1, 2) + convs  # residual connection

        # + 1 for the dimension-adjustment linear layer
        for i in range(self.highway_layers + 1):
            xs = self.highways[i](xs)

        # sort in descending order of length (required by pack_padded_sequence)
        xs, ilens, sort_idx = self._sort_by_length(xs, ilens)

        # total_length keeps the padded length Tmax after unpacking
        total_length = xs.size(1)
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
        self.gru.flatten_parameters()
        xs, _ = self.gru(xs)
        xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)

        # revert the sorting by length
        xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)

        xs = self.output(xs)  # (B, Tmax, odim)

        return xs, ilens
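
    # Usage sketch (illustrative; the dimensions below are assumed, e.g. an
    # 80-dim log Mel-filterbank input and a 513-dim linear spectrogram output):
    #
    #     cbhg = CBHG(idim=80, odim=513)
    #     xs = torch.randn(2, 100, 80)     # (B, Tmax, idim)
    #     ilens = torch.tensor([100, 80])  # valid lengths per utterance
    #     ys, olens = cbhg(xs, ilens)      # ys: (B, Tmax, odim)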
|
    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequence of inputs (T, idim).

        Returns:
            Tensor: The sequence of outputs (T, odim).

        """
        assert len(x.size()) == 2
        xs = x.unsqueeze(0)  # add batch dimension -> (1, T, idim)
        ilens = x.new([x.size(0)]).long()

        return self.forward(xs, ilens)[0][0]
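
    # Inference sketch (illustrative; an 80-dim input feature is assumed):
    #
    #     spc = cbhg.inference(torch.randn(100, 80))  # (T, idim) -> (T, odim)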
|
    def _sort_by_length(self, xs, ilens):
        # sort the batch in descending order of length
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], ilens[sort_idx], sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        # restore the original batch order after processing
        _, revert_idx = sort_idx.sort(0)
        return xs[revert_idx], ilens[revert_idx]
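
    # Worked example of the sort/revert round trip (illustrative): for
    # ilens = [3, 5, 4], _sort_by_length yields sort_idx = [1, 2, 0] (the
    # batch reordered to lengths [5, 4, 3]); _revert_sort_by_length then
    # computes revert_idx = [2, 0, 1], restoring the original batch order.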
|
|
class HighwayNet(torch.nn.Module):
    """Highway Network module.

    This is a module of Highway Network introduced in `Highway Networks`_.

    .. _`Highway Networks`: https://arxiv.org/abs/1505.00387

    """

    def __init__(self, idim):
        """Initialize Highway Network module.

        Args:
            idim (int): Dimension of the inputs.

        """
        super(HighwayNet, self).__init__()
        self.idim = idim
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(idim, idim), torch.nn.ReLU()
        )
        self.gate = torch.nn.Sequential(torch.nn.Linear(idim, idim), torch.nn.Sigmoid())
|
    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of inputs (B, ..., idim).

        Returns:
            Tensor: Batch of outputs, which have the same shape as the inputs
                (B, ..., idim).

        """
        proj = self.projection(x)  # H(x): candidate transform
        gate = self.gate(x)  # T(x): transform gate in (0, 1)
        # highway connection: y = H(x) * T(x) + x * (1 - T(x))
        return proj * gate + x * (1.0 - gate)
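
# Usage sketch (illustrative): a HighwayNet layer preserves the input shape,
#
#     layer = HighwayNet(128)
#     ys = layer(torch.randn(2, 100, 128))  # -> (2, 100, 128)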