|
|
|
|
|
|
|
"""TTS-Transformer related modules.""" |
|
|
|
import logging |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss |
|
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( |
|
Tacotron2Loss as TransformerLoss, |
|
) |
|
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
|
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet |
|
from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet |
|
from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet |
|
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention |
|
from espnet.nets.pytorch_backend.transformer.decoder import Decoder |
|
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.encoder import Encoder |
|
from espnet.nets.pytorch_backend.transformer.initializer import initialize |
|
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask |
|
from espnet.nets.tts_interface import TTSInterface |
|
from espnet.utils.cli_utils import strtobool |
|
from espnet.utils.fill_missing_args import fill_missing_args |
|
|
|
|
|
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): |
|
"""Guided attention loss function module for multi head attention. |
|
|
|
Args: |
|
sigma (float, optional): Standard deviation to control |
|
            how close the attention weights are to a diagonal.
|
alpha (float, optional): Scaling coefficient (lambda). |
|
reset_always (bool, optional): Whether to always reset masks. |
|
|
|
""" |
|
|
|
def forward(self, att_ws, ilens, olens): |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
att_ws (Tensor): |
|
Batch of multi head attention weights (B, H, T_max_out, T_max_in). |
|
            ilens (LongTensor): Batch of input lengths (B,).

            olens (LongTensor): Batch of output lengths (B,).
|
|
|
Returns: |
|
Tensor: Guided attention loss value. |
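
        Examples:
            A minimal sketch with random attention weights; the shapes and
            lengths below are purely illustrative:

            >>> import torch
            >>> criterion = GuidedMultiHeadAttentionLoss(sigma=0.4, alpha=1.0)
            >>> att_ws = torch.rand(2, 4, 10, 8)  # (B, H, T_max_out, T_max_in)
            >>> ilens = torch.LongTensor([8, 6])
            >>> olens = torch.LongTensor([10, 9])
            >>> loss = criterion(att_ws, ilens, olens)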
|
|
|
""" |
|
if self.guided_attn_masks is None: |
|
self.guided_attn_masks = ( |
|
self._make_guided_attention_masks(ilens, olens) |
|
.to(att_ws.device) |
|
.unsqueeze(1) |
|
) |
|
if self.masks is None: |
|
self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) |
|
losses = self.guided_attn_masks * att_ws |
|
loss = torch.mean(losses.masked_select(self.masks)) |
|
if self.reset_always: |
|
self._reset_masks() |
|
|
|
return self.alpha * loss |
|
|
|
|
|
try: |
|
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport |
|
except (ImportError, TypeError): |
|
TTSPlot = None |
|
else: |
|
|
|
class TTSPlot(PlotAttentionReport): |
|
"""Attention plot module for TTS-Transformer.""" |
|
|
|
def plotfn( |
|
self, data_dict, uttid_list, attn_dict, outdir, suffix="png", savefn=None |
|
): |
|
"""Plot multi head attentions. |
|
|
|
Args: |
|
data_dict (dict): Utts info from json file. |
|
uttid_list (list): List of utt_id. |
|
attn_dict (dict): Multi head attention dict. |
|
                    Values should be numpy.ndarray (H, L, T).
|
outdir (str): Directory name to save figures. |
|
suffix (str): Filename suffix including image type (e.g., png). |
|
savefn (function): Function to save figures. |
|
|
|
""" |
|
import matplotlib.pyplot as plt |
|
from espnet.nets.pytorch_backend.transformer.plot import ( |
|
_plot_and_save_attention, |
|
) |
|
|
|
for name, att_ws in attn_dict.items(): |
|
for utt_id, att_w in zip(uttid_list, att_ws): |
|
filename = "%s/%s.%s.%s" % (outdir, utt_id, name, suffix) |
|
if "fbank" in name: |
|
fig = plt.Figure() |
|
ax = fig.subplots(1, 1) |
|
ax.imshow(att_w, aspect="auto") |
|
ax.set_xlabel("frames") |
|
ax.set_ylabel("fbank coeff") |
|
fig.tight_layout() |
|
else: |
|
fig = _plot_and_save_attention(att_w, filename) |
|
savefn(fig, filename) |
|
|
|
|
|
class Transformer(TTSInterface, torch.nn.Module): |
|
"""Text-to-Speech Transformer module. |
|
|
|
This is a module of text-to-speech Transformer described |
|
in `Neural Speech Synthesis with Transformer Network`_, |
|
    which converts a sequence of characters

    or phonemes into a sequence of Mel-filterbanks.
|
|
|
.. _`Neural Speech Synthesis with Transformer Network`: |
|
https://arxiv.org/pdf/1809.08895.pdf |
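
    Examples:
        A minimal construction sketch; `idim`/`odim` are illustrative and all
        other hyperparameters are filled with the defaults defined in
        `add_arguments`:

        >>> from argparse import Namespace
        >>> model = Transformer(idim=50, odim=80, args=Namespace())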
|
|
|
""" |
|
|
|
@staticmethod |
|
def add_arguments(parser): |
|
"""Add model-specific arguments to the parser.""" |
|
group = parser.add_argument_group("transformer model setting") |
|
|
|
group.add_argument( |
|
"--embed-dim", |
|
default=512, |
|
type=int, |
|
help="Dimension of character embedding in encoder prenet", |
|
) |
|
group.add_argument( |
|
"--eprenet-conv-layers", |
|
default=3, |
|
type=int, |
|
help="Number of encoder prenet convolution layers", |
|
) |
|
group.add_argument( |
|
"--eprenet-conv-chans", |
|
default=256, |
|
type=int, |
|
help="Number of encoder prenet convolution channels", |
|
) |
|
group.add_argument( |
|
"--eprenet-conv-filts", |
|
default=5, |
|
type=int, |
|
help="Filter size of encoder prenet convolution", |
|
) |
|
group.add_argument( |
|
"--dprenet-layers", |
|
default=2, |
|
type=int, |
|
help="Number of decoder prenet layers", |
|
) |
|
group.add_argument( |
|
"--dprenet-units", |
|
default=256, |
|
type=int, |
|
help="Number of decoder prenet hidden units", |
|
) |
|
group.add_argument( |
|
"--elayers", default=3, type=int, help="Number of encoder layers" |
|
) |
|
group.add_argument( |
|
"--eunits", default=1536, type=int, help="Number of encoder hidden units" |
|
) |
|
group.add_argument( |
|
"--adim", |
|
default=384, |
|
type=int, |
|
help="Number of attention transformation dimensions", |
|
) |
|
group.add_argument( |
|
"--aheads", |
|
default=4, |
|
type=int, |
|
help="Number of heads for multi head attention", |
|
) |
|
group.add_argument( |
|
"--dlayers", default=3, type=int, help="Number of decoder layers" |
|
) |
|
group.add_argument( |
|
"--dunits", default=1536, type=int, help="Number of decoder hidden units" |
|
) |
|
group.add_argument( |
|
"--positionwise-layer-type", |
|
default="linear", |
|
type=str, |
|
choices=["linear", "conv1d", "conv1d-linear"], |
|
help="Positionwise layer type.", |
|
) |
|
group.add_argument( |
|
"--positionwise-conv-kernel-size", |
|
default=1, |
|
type=int, |
|
help="Kernel size of positionwise conv1d layer", |
|
) |
|
group.add_argument( |
|
"--postnet-layers", default=5, type=int, help="Number of postnet layers" |
|
) |
|
group.add_argument( |
|
"--postnet-chans", default=256, type=int, help="Number of postnet channels" |
|
) |
|
group.add_argument( |
|
"--postnet-filts", default=5, type=int, help="Filter size of postnet" |
|
) |
|
group.add_argument( |
|
"--use-scaled-pos-enc", |
|
default=True, |
|
type=strtobool, |
|
help="Use trainable scaled positional encoding " |
|
"instead of the fixed scale one.", |
|
) |
|
group.add_argument( |
|
"--use-batch-norm", |
|
default=True, |
|
type=strtobool, |
|
help="Whether to use batch normalization", |
|
) |
|
group.add_argument( |
|
"--encoder-normalize-before", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to apply layer norm before encoder block", |
|
) |
|
group.add_argument( |
|
"--decoder-normalize-before", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to apply layer norm before decoder block", |
|
) |
|
group.add_argument( |
|
"--encoder-concat-after", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to concatenate attention layer's input and output in encoder", |
|
) |
|
group.add_argument( |
|
"--decoder-concat-after", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to concatenate attention layer's input and output in decoder", |
|
) |
|
group.add_argument( |
|
"--reduction-factor", default=1, type=int, help="Reduction factor" |
|
) |
|
group.add_argument( |
|
"--spk-embed-dim", |
|
default=None, |
|
type=int, |
|
help="Number of speaker embedding dimensions", |
|
) |
|
group.add_argument( |
|
"--spk-embed-integration-type", |
|
type=str, |
|
default="add", |
|
choices=["add", "concat"], |
|
help="How to integrate speaker embedding", |
|
) |
|
|
|
group.add_argument( |
|
"--transformer-init", |
|
type=str, |
|
default="pytorch", |
|
choices=[ |
|
"pytorch", |
|
"xavier_uniform", |
|
"xavier_normal", |
|
"kaiming_uniform", |
|
"kaiming_normal", |
|
], |
|
help="How to initialize transformer parameters", |
|
) |
|
group.add_argument( |
|
"--initial-encoder-alpha", |
|
type=float, |
|
default=1.0, |
|
help="Initial alpha value in encoder's ScaledPositionalEncoding", |
|
) |
|
group.add_argument( |
|
"--initial-decoder-alpha", |
|
type=float, |
|
default=1.0, |
|
help="Initial alpha value in decoder's ScaledPositionalEncoding", |
|
) |
|
group.add_argument( |
|
"--transformer-lr", |
|
default=1.0, |
|
type=float, |
|
help="Initial value of learning rate", |
|
) |
|
group.add_argument( |
|
"--transformer-warmup-steps", |
|
default=4000, |
|
type=int, |
|
help="Optimizer warmup steps", |
|
) |
|
group.add_argument( |
|
"--transformer-enc-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer encoder except for attention", |
|
) |
|
group.add_argument( |
|
"--transformer-enc-positional-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer encoder positional encoding", |
|
) |
|
group.add_argument( |
|
"--transformer-enc-attn-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer encoder self-attention", |
|
) |
|
group.add_argument( |
|
"--transformer-dec-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer decoder " |
|
"except for attention and pos encoding", |
|
) |
|
group.add_argument( |
|
"--transformer-dec-positional-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer decoder positional encoding", |
|
) |
|
group.add_argument( |
|
"--transformer-dec-attn-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer decoder self-attention", |
|
) |
|
group.add_argument( |
|
"--transformer-enc-dec-attn-dropout-rate", |
|
default=0.1, |
|
type=float, |
|
help="Dropout rate for transformer encoder-decoder attention", |
|
) |
|
group.add_argument( |
|
"--eprenet-dropout-rate", |
|
default=0.5, |
|
type=float, |
|
help="Dropout rate in encoder prenet", |
|
) |
|
group.add_argument( |
|
"--dprenet-dropout-rate", |
|
default=0.5, |
|
type=float, |
|
help="Dropout rate in decoder prenet", |
|
) |
|
group.add_argument( |
|
"--postnet-dropout-rate", |
|
default=0.5, |
|
type=float, |
|
help="Dropout rate in postnet", |
|
) |
|
group.add_argument( |
|
"--pretrained-model", default=None, type=str, help="Pretrained model path" |
|
) |
|
|
|
group.add_argument( |
|
"--use-masking", |
|
default=True, |
|
type=strtobool, |
|
help="Whether to use masking in calculation of loss", |
|
) |
|
group.add_argument( |
|
"--use-weighted-masking", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to use weighted masking in calculation of loss", |
|
) |
|
group.add_argument( |
|
"--loss-type", |
|
default="L1", |
|
choices=["L1", "L2", "L1+L2"], |
|
help="How to calc loss", |
|
) |
|
group.add_argument( |
|
"--bce-pos-weight", |
|
default=5.0, |
|
type=float, |
|
help="Positive sample weight in BCE calculation " |
|
"(only for use-masking=True)", |
|
) |
|
group.add_argument( |
|
"--use-guided-attn-loss", |
|
default=False, |
|
type=strtobool, |
|
help="Whether to use guided attention loss", |
|
) |
|
group.add_argument( |
|
"--guided-attn-loss-sigma", |
|
default=0.4, |
|
type=float, |
|
help="Sigma in guided attention loss", |
|
) |
|
group.add_argument( |
|
"--guided-attn-loss-lambda", |
|
default=1.0, |
|
type=float, |
|
help="Lambda in guided attention loss", |
|
) |
|
group.add_argument( |
|
"--num-heads-applied-guided-attn", |
|
default=2, |
|
type=int, |
|
help="Number of heads in each layer to be applied guided attention loss" |
|
"if set -1, all of the heads will be applied.", |
|
) |
|
group.add_argument( |
|
"--num-layers-applied-guided-attn", |
|
default=2, |
|
type=int, |
|
help="Number of layers to be applied guided attention loss" |
|
"if set -1, all of the layers will be applied.", |
|
) |
|
group.add_argument( |
|
"--modules-applied-guided-attn", |
|
type=str, |
|
nargs="+", |
|
default=["encoder-decoder"], |
|
help="Module name list to be applied guided attention loss", |
|
) |
|
return parser |
|
|
|
@property |
|
def attention_plot_class(self): |
|
"""Return plot class for attention weight plot.""" |
|
return TTSPlot |
|
|
|
def __init__(self, idim, odim, args=None): |
|
"""Initialize TTS-Transformer module. |
|
|
|
Args: |
|
idim (int): Dimension of the inputs. |
|
odim (int): Dimension of the outputs. |
|
args (Namespace, optional): |
|
- embed_dim (int): Dimension of character embedding. |
|
- eprenet_conv_layers (int): |
|
Number of encoder prenet convolution layers. |
|
- eprenet_conv_chans (int): |
|
Number of encoder prenet convolution channels. |
|
- eprenet_conv_filts (int): Filter size of encoder prenet convolution. |
|
- dprenet_layers (int): Number of decoder prenet layers. |
|
- dprenet_units (int): Number of decoder prenet hidden units. |
|
- elayers (int): Number of encoder layers. |
|
- eunits (int): Number of encoder hidden units. |
|
- adim (int): Number of attention transformation dimensions. |
|
- aheads (int): Number of heads for multi head attention. |
|
- dlayers (int): Number of decoder layers. |
|
- dunits (int): Number of decoder hidden units. |
|
- postnet_layers (int): Number of postnet layers. |
|
- postnet_chans (int): Number of postnet channels. |
|
- postnet_filts (int): Filter size of postnet. |
|
- use_scaled_pos_enc (bool): |
|
Whether to use trainable scaled positional encoding. |
|
- use_batch_norm (bool): |
|
Whether to use batch normalization in encoder prenet. |
|
- encoder_normalize_before (bool): |
|
Whether to perform layer normalization before encoder block. |
|
- decoder_normalize_before (bool): |
|
Whether to perform layer normalization before decoder block. |
|
- encoder_concat_after (bool): Whether to concatenate attention |
|
layer's input and output in encoder. |
|
- decoder_concat_after (bool): Whether to concatenate attention |
|
layer's input and output in decoder. |
|
- reduction_factor (int): Reduction factor. |
|
            - spk_embed_dim (int): Number of speaker embedding dimensions.
|
- spk_embed_integration_type: How to integrate speaker embedding. |
|
- transformer_init (float): How to initialize transformer parameters. |
|
- transformer_lr (float): Initial value of learning rate. |
|
- transformer_warmup_steps (int): Optimizer warmup steps. |
|
- transformer_enc_dropout_rate (float): |
|
Dropout rate in encoder except attention & positional encoding. |
|
- transformer_enc_positional_dropout_rate (float): |
|
Dropout rate after encoder positional encoding. |
|
- transformer_enc_attn_dropout_rate (float): |
|
Dropout rate in encoder self-attention module. |
|
- transformer_dec_dropout_rate (float): |
|
Dropout rate in decoder except attention & positional encoding. |
|
- transformer_dec_positional_dropout_rate (float): |
|
Dropout rate after decoder positional encoding. |
|
- transformer_dec_attn_dropout_rate (float): |
|
                Dropout rate in decoder self-attention module.

            - transformer_enc_dec_attn_dropout_rate (float):

                Dropout rate in encoder-decoder attention module.
|
- eprenet_dropout_rate (float): Dropout rate in encoder prenet. |
|
- dprenet_dropout_rate (float): Dropout rate in decoder prenet. |
|
- postnet_dropout_rate (float): Dropout rate in postnet. |
|
- use_masking (bool): |
|
Whether to apply masking for padded part in loss calculation. |
|
- use_weighted_masking (bool): |
|
Whether to apply weighted masking in loss calculation. |
|
- bce_pos_weight (float): Positive sample weight in bce calculation |
|
(only for use_masking=true). |
|
- loss_type (str): How to calculate loss. |
|
- use_guided_attn_loss (bool): Whether to use guided attention loss. |
|
- num_heads_applied_guided_attn (int): |
|
Number of heads in each layer to apply guided attention loss. |
|
- num_layers_applied_guided_attn (int): |
|
Number of layers to apply guided attention loss. |
|
- modules_applied_guided_attn (list): |
|
List of module names to apply guided attention loss. |
|
            - guided_attn_loss_sigma (float): Sigma in guided attention loss.

            - guided_attn_loss_lambda (float): Lambda in guided attention loss.
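
        Examples:
            A construction sketch overriding a few defaults; any attribute
            missing from the Namespace is filled with its default from
            `add_arguments`, and the values below are illustrative:

            >>> from argparse import Namespace
            >>> args = Namespace(use_guided_attn_loss=True, reduction_factor=2)
            >>> model = Transformer(idim=50, odim=80, args=args)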
|
|
|
""" |
|
|
|
TTSInterface.__init__(self) |
|
torch.nn.Module.__init__(self) |
|
|
|
|
|
args = fill_missing_args(args, self.add_arguments) |
|
|
|
|
|
self.idim = idim |
|
self.odim = odim |
|
self.spk_embed_dim = args.spk_embed_dim |
|
if self.spk_embed_dim is not None: |
|
self.spk_embed_integration_type = args.spk_embed_integration_type |
|
self.use_scaled_pos_enc = args.use_scaled_pos_enc |
|
self.reduction_factor = args.reduction_factor |
|
self.loss_type = args.loss_type |
|
self.use_guided_attn_loss = args.use_guided_attn_loss |
|
if self.use_guided_attn_loss: |
|
if args.num_layers_applied_guided_attn == -1: |
|
self.num_layers_applied_guided_attn = args.elayers |
|
else: |
|
self.num_layers_applied_guided_attn = ( |
|
args.num_layers_applied_guided_attn |
|
) |
|
if args.num_heads_applied_guided_attn == -1: |
|
self.num_heads_applied_guided_attn = args.aheads |
|
else: |
|
self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn |
|
self.modules_applied_guided_attn = args.modules_applied_guided_attn |
|
|
|
|
|
padding_idx = 0 |
|
|
|
|
|
pos_enc_class = ( |
|
ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding |
|
) |
|
|
|
|
|
if args.eprenet_conv_layers != 0: |
|
|
|
encoder_input_layer = torch.nn.Sequential( |
|
EncoderPrenet( |
|
idim=idim, |
|
embed_dim=args.embed_dim, |
|
elayers=0, |
|
econv_layers=args.eprenet_conv_layers, |
|
econv_chans=args.eprenet_conv_chans, |
|
econv_filts=args.eprenet_conv_filts, |
|
use_batch_norm=args.use_batch_norm, |
|
dropout_rate=args.eprenet_dropout_rate, |
|
padding_idx=padding_idx, |
|
), |
|
torch.nn.Linear(args.eprenet_conv_chans, args.adim), |
|
) |
|
else: |
|
encoder_input_layer = torch.nn.Embedding( |
|
num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx |
|
) |
|
self.encoder = Encoder( |
|
idim=idim, |
|
attention_dim=args.adim, |
|
attention_heads=args.aheads, |
|
linear_units=args.eunits, |
|
num_blocks=args.elayers, |
|
input_layer=encoder_input_layer, |
|
dropout_rate=args.transformer_enc_dropout_rate, |
|
positional_dropout_rate=args.transformer_enc_positional_dropout_rate, |
|
attention_dropout_rate=args.transformer_enc_attn_dropout_rate, |
|
pos_enc_class=pos_enc_class, |
|
normalize_before=args.encoder_normalize_before, |
|
concat_after=args.encoder_concat_after, |
|
positionwise_layer_type=args.positionwise_layer_type, |
|
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, |
|
) |
|
|
|
|
|
if self.spk_embed_dim is not None: |
|
if self.spk_embed_integration_type == "add": |
|
self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) |
|
else: |
|
self.projection = torch.nn.Linear( |
|
args.adim + self.spk_embed_dim, args.adim |
|
) |
|
|
|
|
|
if args.dprenet_layers != 0: |
|
|
|
decoder_input_layer = torch.nn.Sequential( |
|
DecoderPrenet( |
|
idim=odim, |
|
n_layers=args.dprenet_layers, |
|
n_units=args.dprenet_units, |
|
dropout_rate=args.dprenet_dropout_rate, |
|
), |
|
torch.nn.Linear(args.dprenet_units, args.adim), |
|
) |
|
else: |
|
decoder_input_layer = "linear" |
|
self.decoder = Decoder( |
|
odim=-1, |
|
attention_dim=args.adim, |
|
attention_heads=args.aheads, |
|
linear_units=args.dunits, |
|
num_blocks=args.dlayers, |
|
dropout_rate=args.transformer_dec_dropout_rate, |
|
positional_dropout_rate=args.transformer_dec_positional_dropout_rate, |
|
self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, |
|
src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate, |
|
input_layer=decoder_input_layer, |
|
use_output_layer=False, |
|
pos_enc_class=pos_enc_class, |
|
normalize_before=args.decoder_normalize_before, |
|
concat_after=args.decoder_concat_after, |
|
) |
|
|
|
|
|
self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) |
|
self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) |
|
|
|
|
|
self.postnet = ( |
|
None |
|
if args.postnet_layers == 0 |
|
else Postnet( |
|
idim=idim, |
|
odim=odim, |
|
n_layers=args.postnet_layers, |
|
n_chans=args.postnet_chans, |
|
n_filts=args.postnet_filts, |
|
use_batch_norm=args.use_batch_norm, |
|
dropout_rate=args.postnet_dropout_rate, |
|
) |
|
) |
|
|
|
|
|
self.criterion = TransformerLoss( |
|
use_masking=args.use_masking, |
|
use_weighted_masking=args.use_weighted_masking, |
|
bce_pos_weight=args.bce_pos_weight, |
|
) |
|
if self.use_guided_attn_loss: |
|
self.attn_criterion = GuidedMultiHeadAttentionLoss( |
|
sigma=args.guided_attn_loss_sigma, |
|
alpha=args.guided_attn_loss_lambda, |
|
) |
|
|
|
|
|
self._reset_parameters( |
|
init_type=args.transformer_init, |
|
init_enc_alpha=args.initial_encoder_alpha, |
|
init_dec_alpha=args.initial_decoder_alpha, |
|
) |
|
|
|
|
|
if args.pretrained_model is not None: |
|
self.load_pretrained_model(args.pretrained_model) |
|
|
|
def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): |
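        """Initialize Transformer parameters and the alphas of ScaledPositionalEncoding."""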
|
|
|
initialize(self, init_type) |
|
|
|
|
|
if self.use_scaled_pos_enc: |
|
self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) |
|
self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) |
|
|
|
def _add_first_frame_and_remove_last_frame(self, ys): |
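        """Prepend a zero frame and drop the last frame to shift targets for teacher forcing."""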
|
ys_in = torch.cat( |
|
[ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 |
|
) |
|
return ys_in |
|
|
|
def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
xs (Tensor): Batch of padded character ids (B, Tmax). |
|
ilens (LongTensor): Batch of lengths of each input batch (B,). |
|
            ys (Tensor): Batch of padded target features (B, Lmax, odim).

            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).

            olens (LongTensor): Batch of the lengths of each target (B,).
|
spembs (Tensor, optional): |
|
Batch of speaker embedding vectors (B, spk_embed_dim). |
|
|
|
Returns: |
|
Tensor: Loss value. |
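
        Examples:
            A teacher-forced training sketch on a random, purely illustrative
            mini-batch (assumes a working espnet1 environment for the internal
            reporter):

            >>> import torch
            >>> from argparse import Namespace
            >>> model = Transformer(idim=50, odim=80, args=Namespace())
            >>> xs = torch.randint(1, 50, (2, 10))
            >>> ilens = torch.LongTensor([10, 7])
            >>> ys = torch.randn(2, 30, 80)
            >>> labels = torch.zeros(2, 30)
            >>> labels[0, -1] = labels[1, 23] = 1.0
            >>> olens = torch.LongTensor([30, 24])
            >>> loss = model(xs, ilens, ys, labels, olens)
            >>> loss.backward()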
|
|
|
""" |
|
|
|
max_ilen = max(ilens) |
|
max_olen = max(olens) |
|
if max_ilen != xs.shape[1]: |
|
xs = xs[:, :max_ilen] |
|
if max_olen != ys.shape[1]: |
|
ys = ys[:, :max_olen] |
|
labels = labels[:, :max_olen] |
|
|
|
|
|
x_masks = self._source_mask(ilens) |
|
hs, h_masks = self.encoder(xs, x_masks) |
|
|
|
|
|
if self.spk_embed_dim is not None: |
|
hs = self._integrate_with_spk_embed(hs, spembs) |
|
|
|
|
|
if self.reduction_factor > 1: |
|
ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] |
|
olens_in = olens.new([olen // self.reduction_factor for olen in olens]) |
|
else: |
|
ys_in, olens_in = ys, olens |
|
|
|
|
|
ys_in = self._add_first_frame_and_remove_last_frame(ys_in) |
|
|
|
|
|
y_masks = self._target_mask(olens_in) |
|
zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) |
|
|
|
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) |
|
|
|
logits = self.prob_out(zs).view(zs.size(0), -1) |
|
|
|
|
|
if self.postnet is None: |
|
after_outs = before_outs |
|
else: |
|
after_outs = before_outs + self.postnet( |
|
before_outs.transpose(1, 2) |
|
).transpose(1, 2) |
|
|
|
|
|
if self.reduction_factor > 1: |
|
olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) |
|
max_olen = max(olens) |
|
ys = ys[:, :max_olen] |
|
labels = labels[:, :max_olen] |
|
labels[:, -1] = 1.0 |
|
|
|
|
|
l1_loss, l2_loss, bce_loss = self.criterion( |
|
after_outs, before_outs, logits, ys, labels, olens |
|
) |
|
if self.loss_type == "L1": |
|
loss = l1_loss + bce_loss |
|
elif self.loss_type == "L2": |
|
loss = l2_loss + bce_loss |
|
elif self.loss_type == "L1+L2": |
|
loss = l1_loss + l2_loss + bce_loss |
|
else: |
|
raise ValueError("unknown --loss-type " + self.loss_type) |
|
report_keys = [ |
|
{"l1_loss": l1_loss.item()}, |
|
{"l2_loss": l2_loss.item()}, |
|
{"bce_loss": bce_loss.item()}, |
|
{"loss": loss.item()}, |
|
] |
|
|
|
|
|
if self.use_guided_attn_loss: |
|
|
|
if "encoder" in self.modules_applied_guided_attn: |
|
att_ws = [] |
|
for idx, layer_idx in enumerate( |
|
reversed(range(len(self.encoder.encoders))) |
|
): |
|
att_ws += [ |
|
self.encoder.encoders[layer_idx].self_attn.attn[ |
|
:, : self.num_heads_applied_guided_attn |
|
] |
|
] |
|
if idx + 1 == self.num_layers_applied_guided_attn: |
|
break |
|
att_ws = torch.cat(att_ws, dim=1) |
|
enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens) |
|
loss = loss + enc_attn_loss |
|
report_keys += [{"enc_attn_loss": enc_attn_loss.item()}] |
|
|
|
if "decoder" in self.modules_applied_guided_attn: |
|
att_ws = [] |
|
for idx, layer_idx in enumerate( |
|
reversed(range(len(self.decoder.decoders))) |
|
): |
|
att_ws += [ |
|
self.decoder.decoders[layer_idx].self_attn.attn[ |
|
:, : self.num_heads_applied_guided_attn |
|
] |
|
] |
|
if idx + 1 == self.num_layers_applied_guided_attn: |
|
break |
|
att_ws = torch.cat(att_ws, dim=1) |
|
dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) |
|
loss = loss + dec_attn_loss |
|
report_keys += [{"dec_attn_loss": dec_attn_loss.item()}] |
|
|
|
if "encoder-decoder" in self.modules_applied_guided_attn: |
|
att_ws = [] |
|
for idx, layer_idx in enumerate( |
|
reversed(range(len(self.decoder.decoders))) |
|
): |
|
att_ws += [ |
|
self.decoder.decoders[layer_idx].src_attn.attn[ |
|
:, : self.num_heads_applied_guided_attn |
|
] |
|
] |
|
if idx + 1 == self.num_layers_applied_guided_attn: |
|
break |
|
att_ws = torch.cat(att_ws, dim=1) |
|
enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in) |
|
loss = loss + enc_dec_attn_loss |
|
report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}] |
|
|
|
|
|
if self.use_scaled_pos_enc: |
|
report_keys += [ |
|
{"encoder_alpha": self.encoder.embed[-1].alpha.data.item()}, |
|
{"decoder_alpha": self.decoder.embed[-1].alpha.data.item()}, |
|
] |
|
self.reporter.report(report_keys) |
|
|
|
return loss |
|
|
|
def inference(self, x, inference_args, spemb=None, *args, **kwargs): |
|
"""Generate the sequence of features given the sequences of characters. |
|
|
|
Args: |
|
x (Tensor): Input sequence of characters (T,). |
|
inference_args (Namespace): |
|
- threshold (float): Threshold in inference. |
|
- minlenratio (float): Minimum length ratio in inference. |
|
- maxlenratio (float): Maximum length ratio in inference. |
|
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). |
|
|
|
Returns: |
|
Tensor: Output sequence of features (L, odim). |
|
Tensor: Output sequence of stop probabilities (L,). |
|
Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). |
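
        Examples:
            A synthesis sketch; with an untrained model the output is
            meaningless, so this only illustrates the call signature:

            >>> import torch
            >>> from argparse import Namespace
            >>> model = Transformer(idim=50, odim=80, args=Namespace())
            >>> x = torch.randint(1, 50, (10,))
            >>> inference_args = Namespace(
            ...     threshold=0.5, minlenratio=0.0, maxlenratio=10.0)
            >>> outs, probs, att_ws = model.inference(x, inference_args)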
|
|
|
""" |
|
|
|
threshold = inference_args.threshold |
|
minlenratio = inference_args.minlenratio |
|
maxlenratio = inference_args.maxlenratio |
|
use_att_constraint = getattr( |
|
inference_args, "use_att_constraint", False |
|
) |
|
if use_att_constraint: |
|
logging.warning( |
|
"Attention constraint is not yet supported in Transformer. Not enabled." |
|
) |
|
|
|
|
|
xs = x.unsqueeze(0) |
|
hs, _ = self.encoder(xs, None) |
|
|
|
|
|
if self.spk_embed_dim is not None: |
|
spembs = spemb.unsqueeze(0) |
|
hs = self._integrate_with_spk_embed(hs, spembs) |
|
|
|
|
|
maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) |
|
minlen = int(hs.size(1) * minlenratio / self.reduction_factor) |
|
|
|
|
|
idx = 0 |
|
ys = hs.new_zeros(1, 1, self.odim) |
|
outs, probs = [], [] |
|
|
|
|
|
z_cache = self.decoder.init_state(x) |
|
while True: |
|
|
|
idx += 1 |
|
|
|
|
|
y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) |
|
z, z_cache = self.decoder.forward_one_step( |
|
ys, y_masks, hs, cache=z_cache |
|
) |
|
outs += [ |
|
self.feat_out(z).view(self.reduction_factor, self.odim) |
|
] |
|
probs += [torch.sigmoid(self.prob_out(z))[0]] |
|
|
|
|
|
ys = torch.cat( |
|
(ys, outs[-1][-1].view(1, 1, self.odim)), dim=1 |
|
) |
|
|
|
|
|
att_ws_ = [] |
|
for name, m in self.named_modules(): |
|
if isinstance(m, MultiHeadedAttention) and "src" in name: |
|
att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] |
|
if idx == 1: |
|
att_ws = att_ws_ |
|
else: |
|
|
|
att_ws = [ |
|
torch.cat([att_w, att_w_], dim=1) |
|
for att_w, att_w_ in zip(att_ws, att_ws_) |
|
] |
|
|
|
|
|
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: |
|
|
|
if idx < minlen: |
|
continue |
|
outs = ( |
|
torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) |
|
) |
|
if self.postnet is not None: |
|
outs = outs + self.postnet(outs) |
|
outs = outs.transpose(2, 1).squeeze(0) |
|
probs = torch.cat(probs, dim=0) |
|
break |
|
|
|
|
|
att_ws = torch.stack(att_ws, dim=0) |
|
|
|
return outs, probs, att_ws |
|
|
|
def calculate_all_attentions( |
|
self, |
|
xs, |
|
ilens, |
|
ys, |
|
olens, |
|
spembs=None, |
|
skip_output=False, |
|
keep_tensor=False, |
|
*args, |
|
**kwargs |
|
): |
|
"""Calculate all of the attention weights. |
|
|
|
Args: |
|
xs (Tensor): Batch of padded character ids (B, Tmax). |
|
ilens (LongTensor): Batch of lengths of each input batch (B,). |
|
ys (Tensor): Batch of padded target features (B, Lmax, odim). |
|
olens (LongTensor): Batch of the lengths of each target (B,). |
|
spembs (Tensor, optional): |
|
Batch of speaker embedding vectors (B, spk_embed_dim). |
|
            skip_output (bool, optional): Whether to skip calculating the final output.

            keep_tensor (bool, optional): Whether to keep the original tensors
                instead of converting them to numpy arrays.
|
|
|
Returns: |
|
dict: Dict of attention weights and outputs. |
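
        Examples:
            A sketch reusing the illustrative batch from the `forward` example
            above; the result maps attention-module names to per-utterance
            weight arrays:

            >>> att_ws_dict = model.calculate_all_attentions(
            ...     xs, ilens, ys, olens, skip_output=True)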
|
|
|
""" |
|
self.eval() |
|
with torch.no_grad(): |
|
|
|
x_masks = self._source_mask(ilens) |
|
hs, h_masks = self.encoder(xs, x_masks) |
|
|
|
|
|
if self.spk_embed_dim is not None: |
|
hs = self._integrate_with_spk_embed(hs, spembs) |
|
|
|
|
|
|
|
if self.reduction_factor > 1: |
|
ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] |
|
olens_in = olens.new([olen // self.reduction_factor for olen in olens]) |
|
else: |
|
ys_in, olens_in = ys, olens |
|
|
|
|
|
ys_in = self._add_first_frame_and_remove_last_frame(ys_in) |
|
|
|
|
|
y_masks = self._target_mask(olens_in) |
|
zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) |
|
|
|
|
|
if not skip_output: |
|
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) |
|
if self.postnet is None: |
|
after_outs = before_outs |
|
else: |
|
after_outs = before_outs + self.postnet( |
|
before_outs.transpose(1, 2) |
|
).transpose(1, 2) |
|
|
|
|
|
if self.reduction_factor > 1: |
|
olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) |
|
|
|
|
|
att_ws_dict = dict() |
|
if keep_tensor: |
|
for name, m in self.named_modules(): |
|
if isinstance(m, MultiHeadedAttention): |
|
att_ws_dict[name] = m.attn |
|
if not skip_output: |
|
att_ws_dict["before_postnet_fbank"] = before_outs |
|
att_ws_dict["after_postnet_fbank"] = after_outs |
|
else: |
|
for name, m in self.named_modules(): |
|
if isinstance(m, MultiHeadedAttention): |
|
attn = m.attn.cpu().numpy() |
|
if "encoder" in name: |
|
attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] |
|
elif "decoder" in name: |
|
if "src" in name: |
|
attn = [ |
|
a[:, :ol, :il] |
|
for a, il, ol in zip( |
|
attn, ilens.tolist(), olens_in.tolist() |
|
) |
|
] |
|
elif "self" in name: |
|
attn = [ |
|
a[:, :l, :l] for a, l in zip(attn, olens_in.tolist()) |
|
] |
|
else: |
|
logging.warning("unknown attention module: " + name) |
|
else: |
|
logging.warning("unknown attention module: " + name) |
|
att_ws_dict[name] = attn |
|
if not skip_output: |
|
before_outs = before_outs.cpu().numpy() |
|
after_outs = after_outs.cpu().numpy() |
|
att_ws_dict["before_postnet_fbank"] = [ |
|
m[:l].T for m, l in zip(before_outs, olens.tolist()) |
|
] |
|
att_ws_dict["after_postnet_fbank"] = [ |
|
m[:l].T for m, l in zip(after_outs, olens.tolist()) |
|
] |
|
self.train() |
|
return att_ws_dict |
|
|
|
def _integrate_with_spk_embed(self, hs, spembs): |
|
"""Integrate speaker embedding with hidden states. |
|
|
|
Args: |
|
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). |
|
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). |
|
|
|
Returns: |
|
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) |
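
        Examples:
            A sketch for the default "add" integration; `spk_embed_dim` and the
            shapes below are illustrative (`adim` defaults to 384):

            >>> import torch
            >>> from argparse import Namespace
            >>> model = Transformer(idim=50, odim=80, args=Namespace(spk_embed_dim=64))
            >>> hs = torch.randn(2, 10, 384)  # (B, Tmax, adim)
            >>> spembs = torch.randn(2, 64)  # (B, spk_embed_dim)
            >>> hs = model._integrate_with_spk_embed(hs, spembs)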
|
|
|
""" |
|
if self.spk_embed_integration_type == "add": |
|
|
|
spembs = self.projection(F.normalize(spembs)) |
|
hs = hs + spembs.unsqueeze(1) |
|
elif self.spk_embed_integration_type == "concat": |
|
|
|
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) |
|
hs = self.projection(torch.cat([hs, spembs], dim=-1)) |
|
else: |
|
raise NotImplementedError("support only add or concat.") |
|
|
|
return hs |
|
|
|
def _source_mask(self, ilens): |
|
"""Make masks for self-attention. |
|
|
|
Args: |
|
ilens (LongTensor or List): Batch of lengths (B,). |
|
|
|
Returns: |
|
Tensor: Mask tensor for self-attention. |
|
                dtype=torch.uint8 in PyTorch < 1.2,

                dtype=torch.bool in PyTorch >= 1.2.
|
|
|
Examples: |
|
>>> ilens = [5, 3] |
|
>>> self._source_mask(ilens) |
|
            tensor([[[1, 1, 1, 1, 1]],

                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
|
|
|
""" |
|
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) |
|
return x_masks.unsqueeze(-2) |
|
|
|
def _target_mask(self, olens): |
|
"""Make masks for masked self-attention. |
|
|
|
Args: |
|
olens (LongTensor or List): Batch of lengths (B,). |
|
|
|
Returns: |
|
Tensor: Mask tensor for masked self-attention. |
|
                dtype=torch.uint8 in PyTorch < 1.2,

                dtype=torch.bool in PyTorch >= 1.2.
|
|
|
Examples: |
|
>>> olens = [5, 3] |
|
>>> self._target_mask(olens) |
|
tensor([[[1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0], |
|
[1, 1, 1, 1, 0], |
|
[1, 1, 1, 1, 1]], |
|
[[1, 0, 0, 0, 0], |
|
[1, 1, 0, 0, 0], |
|
[1, 1, 1, 0, 0], |
|
[1, 1, 1, 0, 0], |
|
[1, 1, 1, 0, 0]]], dtype=torch.uint8) |
|
|
|
""" |
|
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) |
|
s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) |
|
return y_masks.unsqueeze(-2) & s_masks |
|
|
|
@property |
|
def base_plot_keys(self): |
|
"""Return base key names to plot during training. |
|
|
|
        Keys should match what `chainer.reporter` reports.

        If you add the key `loss`, the reporter will report `main/loss`

        and `validation/main/loss` values.

        In addition, `loss.png` will be created as a figure visualizing `main/loss`

        and `validation/main/loss` values.
|
|
|
Returns: |
|
list: List of strings which are base keys to plot during training. |
|
|
|
""" |
|
plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"] |
|
if self.use_scaled_pos_enc: |
|
plot_keys += ["encoder_alpha", "decoder_alpha"] |
|
if self.use_guided_attn_loss: |
|
if "encoder" in self.modules_applied_guided_attn: |
|
plot_keys += ["enc_attn_loss"] |
|
if "decoder" in self.modules_applied_guided_attn: |
|
plot_keys += ["dec_attn_loss"] |
|
if "encoder-decoder" in self.modules_applied_guided_attn: |
|
plot_keys += ["enc_dec_attn_loss"] |
|
|
|
return plot_keys |
|
|