File size: 9,099 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""CBHG related modules."""
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
class CBHGLoss(torch.nn.Module):
    """Loss function module for CBHG.

    Produces both an L1 loss and a mean squared error loss between the CBHG
    prediction and the target spectrogram, optionally ignoring padded frames.

    """

    def __init__(self, use_masking=True):
        """Initialize CBHG loss module.

        Args:
            use_masking (bool): Whether to mask padded part in loss calculation.

        """
        super().__init__()
        self.use_masking = use_masking

    def forward(self, cbhg_outs, spcs, olens):
        """Calculate forward propagation.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.

        """
        predicted, target = cbhg_outs, spcs
        if self.use_masking:
            # Frame-level validity mask (presumably (B, Lmax) from the helper);
            # the extra trailing axis lets masked_select broadcast it over
            # the feature dimension. Both tensors flatten to valid elements only.
            valid = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
            predicted = predicted.masked_select(valid)
            target = target.masked_select(valid)
        l1 = F.l1_loss(predicted, target)
        mse = F.mse_loss(predicted, target)
        return l1, mse
class CBHG(torch.nn.Module):
    """CBHG module to convert log Mel-filterbanks to linear spectrogram.

    This is a module of CBHG introduced
    in `Tacotron: Towards End-to-End Speech Synthesis`_.
    The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.

    .. _`Tacotron: Towards End-to-End Speech Synthesis`:
       https://arxiv.org/abs/1703.10135

    """

    def __init__(
        self,
        idim,
        odim,
        conv_bank_layers=8,
        conv_bank_chans=128,
        conv_proj_filts=3,
        conv_proj_chans=256,
        highway_layers=4,
        highway_units=128,
        gru_units=256,
    ):
        """Initialize CBHG module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            conv_bank_layers (int, optional): The number of convolution bank layers.
            conv_bank_chans (int, optional): The number of channels in convolution bank.
            conv_proj_filts (int, optional):
                Kernel size of convolutional projection layer. Should be odd:
                the projection pads ``(conv_proj_filts - 1) // 2`` frames on
                each side, so an even kernel shortens the sequence by one frame
                and the residual addition in ``forward`` fails.
            conv_proj_chans (int, optional):
                The number of channels in convolutional projection layer.
            highway_layers (int, optional): The number of highway network layers.
            highway_units (int, optional): The number of highway network units.
            gru_units (int, optional): The number of GRU units (for both directions).
                Should be even: each direction gets ``gru_units // 2`` units
                while the output layer expects ``gru_units`` input features.

        """
        super(CBHG, self).__init__()
        self.idim = idim
        self.odim = odim
        self.conv_bank_layers = conv_bank_layers
        self.conv_bank_chans = conv_bank_chans
        self.conv_proj_filts = conv_proj_filts
        self.conv_proj_chans = conv_proj_chans
        self.highway_layers = highway_layers
        self.highway_units = highway_units
        self.gru_units = gru_units

        # define 1d convolution bank: kernel sizes 1..conv_bank_layers.
        # Even kernels get one extra frame of right padding so that every
        # branch keeps the input length Tmax.
        self.conv_bank = torch.nn.ModuleList()
        for k in range(1, self.conv_bank_layers + 1):
            if k % 2 != 0:
                padding = (k - 1) // 2
            else:
                padding = ((k - 1) // 2, (k - 1) // 2 + 1)
            self.conv_bank += [
                torch.nn.Sequential(
                    torch.nn.ConstantPad1d(padding, 0.0),
                    torch.nn.Conv1d(
                        idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
                    ),
                    torch.nn.BatchNorm1d(self.conv_bank_chans),
                    torch.nn.ReLU(),
                )
            ]

        # define max pooling (need padding for one-side to keep same length)
        self.max_pool = torch.nn.Sequential(
            torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
        )

        # define 1d convolution projection (back to idim for the residual add)
        self.projections = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.conv_bank_chans * self.conv_bank_layers,
                self.conv_proj_chans,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.conv_proj_chans),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                self.conv_proj_chans,
                self.idim,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.idim),
        )

        # define highway network; the leading Linear adjusts idim -> highway_units
        self.highways = torch.nn.ModuleList()
        self.highways += [torch.nn.Linear(idim, self.highway_units)]
        for _ in range(self.highway_layers):
            self.highways += [HighwayNet(self.highway_units)]

        # define bidirectional GRU
        self.gru = torch.nn.GRU(
            self.highway_units,
            gru_units // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # define final projection
        self.output = torch.nn.Linear(gru_units, odim, bias=True)

    def forward(self, xs, ilens):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
            ilens (LongTensor or Sequence[int]): Batch of lengths of each input
                sequence (B,). Does not need to be sorted.

        Returns:
            Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
            LongTensor: Batch of lengths of each output sequence (B,).

        """
        # Lengths may arrive as a plain Python sequence. Convert them before
        # sorting: _sort_by_length relies on Tensor.sort and tensor indexing.
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)

        xs = xs.transpose(1, 2)  # (B, idim, Tmax)
        convs = torch.cat(
            [conv(xs) for conv in self.conv_bank], dim=1
        )  # (B, #CH * #BANK, Tmax)
        convs = self.max_pool(convs)
        convs = self.projections(convs).transpose(1, 2)  # (B, Tmax, idim)
        xs = xs.transpose(1, 2) + convs  # residual connection

        # the first entry is the dimension adjustment layer
        for highway in self.highways:
            xs = highway(xs)

        # pack_padded_sequence (enforce_sorted=True) needs descending lengths
        xs, ilens, sort_idx = self._sort_by_length(xs, ilens)

        # total_length needs for DataParallel
        # (see https://github.com/pytorch/pytorch/pull/6327)
        total_length = xs.size(1)
        xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
        self.gru.flatten_parameters()
        xs, _ = self.gru(xs)
        xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)

        # revert sorting by length
        xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)

        xs = self.output(xs)  # (B, Tmax, odim)

        return xs, ilens

    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequences of inputs (T, idim).

        Returns:
            Tensor: The sequence of outputs (T, odim).

        """
        assert len(x.size()) == 2
        xs = x.unsqueeze(0)
        ilens = x.new([x.size(0)]).long()
        return self.forward(xs, ilens)[0][0]

    def _sort_by_length(self, xs, ilens):
        """Reorder a batch by descending sequence length.

        Returns:
            Tensor: ``xs`` permuted along the batch axis.
            Tensor: Lengths in descending order.
            Tensor: The permutation applied (indices into the original batch).

        """
        sorted_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], sorted_ilens, sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        """Undo the permutation produced by ``_sort_by_length``."""
        _, revert_idx = sort_idx.sort(0)
        # pad_packed_sequence returns lengths on the CPU, whereas sort_idx lives
        # on the device of the caller's lengths (e.g. CUDA). A CPU tensor cannot
        # be indexed with a CUDA index, so align the index with ilens first.
        return xs[revert_idx], ilens[revert_idx.to(ilens.device)]
class HighwayNet(torch.nn.Module):
    """Highway Network module.

    This is a module of Highway Network introduced in `Highway Networks`_.
    Each call blends a ReLU-transformed copy of the input with the input
    itself, weighted element-wise by a sigmoid gate.

    .. _`Highway Networks`: https://arxiv.org/abs/1505.00387

    """

    def __init__(self, idim):
        """Initialize Highway Network module.

        Args:
            idim (int): Dimension of the inputs (and of the outputs).

        """
        super().__init__()
        self.idim = idim
        # Transform branch: Linear followed by ReLU.
        transform_layers = [torch.nn.Linear(idim, idim), torch.nn.ReLU()]
        self.projection = torch.nn.Sequential(*transform_layers)
        # Gate branch: Linear followed by Sigmoid, giving weights in (0, 1).
        gate_layers = [torch.nn.Linear(idim, idim), torch.nn.Sigmoid()]
        self.gate = torch.nn.Sequential(*gate_layers)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of inputs (B, ..., idim).

        Returns:
            Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).

        """
        transformed = self.projection(x)
        t = self.gate(x)
        carry = 1.0 - t  # share of the untouched input that passes through
        return transformed * t + x * carry
|