|
|
|
|
|
|
|
"""FastSpeech related modules.""" |
|
|
|
import logging |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from espnet.asr.asr_utils import get_model_conf |
|
from espnet.asr.asr_utils import torch_load |
|
from espnet.nets.pytorch_backend.fastspeech.duration_calculator import ( |
|
DurationCalculator, |
|
) |
|
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor |
|
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( |
|
DurationPredictorLoss, |
|
) |
|
from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator |
|
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet |
|
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention |
|
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.encoder import Encoder |
|
from espnet.nets.pytorch_backend.transformer.initializer import initialize |
|
from espnet.nets.tts_interface import TTSInterface |
|
from espnet.utils.cli_utils import strtobool |
|
from espnet.utils.fill_missing_args import fill_missing_args |
|
|
|
|
|
class FeedForwardTransformerLoss(torch.nn.Module):
    """Loss function module for feed-forward Transformer.

    Computes an L1 loss on the predicted features (before the postnet and,
    when given, after the postnet) plus a duration predictor loss.
    """

    def __init__(self, use_masking=True, use_weighted_masking=False):
        """Initialize feed-forward Transformer loss module.

        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.

        Raises:
            ValueError: If both ``use_masking`` and ``use_weighted_masking`` are
                True; the two masking strategies are mutually exclusive.

        """
        super(FeedForwardTransformerLoss, self).__init__()
        # NOTE: an explicit raise (instead of assert) keeps this check under
        # ``python -O``. With both flags on, the padded parts would already be
        # removed (flattening the tensors) before the weights, which are shaped
        # for the padded (B, Lmax, ...) layout, are applied.
        if use_masking and use_weighted_masking:
            raise ValueError(
                "use_masking and use_weighted_masking cannot both be True."
            )
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # weighted masking needs element-wise losses so they can be weighted later
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens):
        """Calculate forward propagation.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim),
                or None when there is no postnet output to score.
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax).
            ys (Tensor): Batch of target features (B, Lmax, odim).
            ds (Tensor): Batch of durations (B, Tmax).
            ilens (LongTensor): Batch of the lengths of each input (B,).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: L1 loss value.
            Tensor: Duration predictor loss value.

        """
        # drop padded positions so they do not contribute to the losses
        if self.use_masking:
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            d_outs = d_outs.masked_select(duration_masks)
            ds = ds.masked_select(duration_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            after_outs = (
                after_outs.masked_select(out_masks) if after_outs is not None else None
            )
            ys = ys.masked_select(out_masks)

        # calculate losses (scalar means, or element-wise when weighting)
        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            l1_loss += self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)

        # weighted masking: each element is weighted by
        # 1 / (valid length of its utterance * batch size [* odim]), so the sum
        # equals the batch average of every utterance's own mean loss
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
            )
            duration_weights /= ds.size(0)

            # apply weights and sum over the non-padded elements only
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (
                duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
            )

        return l1_loss, duration_loss
|
|
|
|
|
class FeedForwardTransformer(TTSInterface, torch.nn.Module):
    """Feed Forward Transformer for TTS a.k.a. FastSpeech.

    This is a module of FastSpeech,
    feed-forward Transformer with duration predictor described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_,
    which does not require any auto-regressive
    processing during inference,
    resulting in fast decoding compared with auto-regressive Transformer.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser.

        Args:
            parser: argparse-style parser providing ``add_argument_group``.

        Returns:
            The same parser, extended with the
            "feed-forward transformer model setting" argument group.

        """
        group = parser.add_argument_group("feed-forward transformer model setting")

        # network structure related
        group.add_argument(
            "--adim",
            default=384,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--elayers", default=6, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits", default=1536, type=int, help="Number of encoder hidden units"
        )
        group.add_argument(
            "--dlayers", default=6, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1536, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--positionwise-layer-type",
            default="linear",
            type=str,
            choices=["linear", "conv1d", "conv1d-linear"],
            help="Positionwise layer type.",
        )
        group.add_argument(
            "--positionwise-conv-kernel-size",
            default=3,
            type=int,
            help="Kernel size of positionwise conv1d layer",
        )
        group.add_argument(
            "--postnet-layers", default=0, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=256, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-scaled-pos-enc",
            default=True,
            type=strtobool,
            help="Use trainable scaled positional encoding "
            "instead of the fixed scale one",
        )
        group.add_argument(
            "--encoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before encoder block",
        )
        group.add_argument(
            "--decoder-normalize-before",
            default=False,
            type=strtobool,
            help="Whether to apply layer norm before decoder block",
        )
        group.add_argument(
            "--encoder-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concatenate attention layer's input and output in encoder",
        )
        group.add_argument(
            "--decoder-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concatenate attention layer's input and output in decoder",
        )
        group.add_argument(
            "--duration-predictor-layers",
            default=2,
            type=int,
            help="Number of layers in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-chans",
            default=384,
            type=int,
            help="Number of channels in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-kernel-size",
            default=3,
            type=int,
            help="Kernel size in duration predictor",
        )
        group.add_argument(
            "--teacher-model",
            default=None,
            type=str,
            nargs="?",
            help="Teacher model file path",
        )
        group.add_argument(
            "--reduction-factor", default=1, type=int, help="Reduction factor"
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spk-embed-integration-type",
            type=str,
            default="add",
            choices=["add", "concat"],
            help="How to integrate speaker embedding",
        )

        # training related
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="How to initialize transformer parameters",
        )
        group.add_argument(
            "--initial-encoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in encoder's ScaledPositionalEncoding",
        )
        group.add_argument(
            "--initial-decoder-alpha",
            type=float,
            default=1.0,
            help="Initial alpha value in decoder's ScaledPositionalEncoding",
        )
        group.add_argument(
            "--transformer-lr",
            default=1.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=4000,
            type=int,
            help="Optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-enc-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder except for attention",
        )
        group.add_argument(
            "--transformer-enc-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder positional encoding",
        )
        group.add_argument(
            "--transformer-enc-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder self-attention",
        )
        group.add_argument(
            "--transformer-dec-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder except "
            "for attention and pos encoding",
        )
        group.add_argument(
            "--transformer-dec-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder positional encoding",
        )
        group.add_argument(
            "--transformer-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer decoder self-attention",
        )
        group.add_argument(
            "--transformer-enc-dec-attn-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for transformer encoder-decoder attention",
        )
        group.add_argument(
            "--duration-predictor-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for duration predictor",
        )
        group.add_argument(
            "--postnet-dropout-rate",
            default=0.5,
            type=float,
            help="Dropout rate in postnet",
        )
        group.add_argument(
            "--transfer-encoder-from-teacher",
            default=True,
            type=strtobool,
            help="Whether to transfer teacher's parameters",
        )
        group.add_argument(
            "--transferred-encoder-module",
            default="all",
            type=str,
            choices=["all", "embed"],
            help="Encoder modeules to be trasferred from teacher",
        )

        # loss related
        group.add_argument(
            "--use-masking",
            default=True,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs (size of the input vocabulary;
                used as ``num_embeddings`` of the input embedding).
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Missing fields are filled with the
                defaults of ``add_arguments``.
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - positionwise_layer_type (str): Positionwise layer type.
                - positionwise_conv_kernel_size (int):
                    Kernel size of positionwise conv1d layer.
                - postnet_layers (int): Number of postnet layers
                    (0 disables the postnet).
                - postnet_chans (int): Number of postnet channels.
                - postnet_filts (int): Filter size of postnet.
                - use_batch_norm (bool): Whether to use batch norm in postnet.
                - use_scaled_pos_enc (bool):
                    Whether to use trainable scaled positional encoding.
                - encoder_normalize_before (bool):
                    Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool):
                    Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention
                    layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention
                    layer's input and output in decoder.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int):
                    Kernel size of duration predictor.
                - duration_predictor_dropout_rate (float):
                    Dropout rate in duration predictor.
                - postnet_dropout_rate (float): Dropout rate in postnet.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - teacher_model (str): Teacher auto-regressive transformer model path.
                - reduction_factor (int): Reduction factor.
                - transformer_init (str): How to initialize transformer parameters.
                - initial_encoder_alpha (float):
                    Initial alpha of encoder's ScaledPositionalEncoding.
                - initial_decoder_alpha (float):
                    Initial alpha of decoder's ScaledPositionalEncoding.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float):
                    Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float):
                    Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float):
                    Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float):
                    Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float):
                    Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float):
                    Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float):
                    Dropout rate in encoder-decoder attention module
                    (not read by this constructor).
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - transfer_encoder_from_teacher:
                    Whether to transfer encoder using teacher encoder parameters.
                - transferred_encoder_module:
                    Encoder module to be initialized using teacher parameters.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments with the defaults of add_arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (
            ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
        )

        # define encoder
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx
        )
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim
                )

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: the decoder reuses the Encoder architecture; it consumes the
        # length-regulated hidden states directly, hence input_layer=None.
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define final projection; each decoder frame yields `reduction_factor`
        # output frames once reshaped to (B, -1, odim) in `_forward`
        self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)

        # define postnet (disabled when postnet_layers == 0)
        self.postnet = (
            None
            if args.postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )
        )

        # initialize parameters (before the teacher is loaded, so the teacher's
        # weights are left untouched)
        self._reset_parameters(
            init_type=args.transformer_init,
            init_enc_alpha=args.initial_encoder_alpha,
            init_dec_alpha=args.initial_decoder_alpha,
        )

        # define teacher model
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator (requires a teacher)
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.criterion = FeedForwardTransformerLoss(
            use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking
        )

    def _forward(
        self,
        xs,
        ilens,
        ys=None,
        olens=None,
        spembs=None,
        ds=None,
        is_inference=False,
        alpha=1.0,
    ):
        """Run encoder, duration predictor / length regulator, decoder and postnet.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input (B,).
            ys (Tensor, optional): Batch of padded target features
                (B, Lmax, odim); only passed to the duration calculator.
            olens (LongTensor, optional): Batch of target lengths (B,); used to
                build the decoder mask (no mask is applied when None).
            spembs (Tensor, optional): Batch of speaker embeddings
                (B, spk_embed_dim).
            ds (LongTensor, optional): Batch of durations (B, Tmax). When None
                in training, they are computed by the teacher-based duration
                calculator.
            is_inference (bool): Whether to use predicted durations.
            alpha (float): Value passed to the length regulator at inference
                (presumably a duration scaling factor).

        Returns:
            tuple: ``(before_outs, after_outs, d_outs)`` in inference, or
                ``(before_outs, after_outs, ds, d_outs)`` in training.

        Raises:
            ValueError: In training, when ``ds`` is None and no teacher model
                (hence no duration calculator) is available.

        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, adim)
        else:
            if ds is None:
                # without this check, calling the missing calculator fails with
                # an opaque "'NoneType' object is not callable"
                if self.duration_calculator is None:
                    raise ValueError(
                        "Durations are required for training: provide them via "
                        "`extras` or specify a teacher model to compute them."
                    )
                with torch.no_grad():
                    # (B, Tmax)
                    ds = self.duration_calculator(xs, ilens, ys, olens, spembs)
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        # forward decoder; its masks are built on the reduced lengths
        if olens is not None:
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(
            zs.size(0), -1, self.odim
        )  # (B, Lmax * r, odim)

        # postnet (residual); applied channel-first, hence the transposes
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, ds, d_outs

    def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            Tensor: Loss value.

        """
        # drop padding beyond the longest sequence in this batch
        xs = xs[:, : max(ilens)]
        ys = ys[:, : max(olens)]
        if extras is not None:
            extras = extras[:, : max(ilens)].squeeze(-1)

        # forward propagation
        before_outs, after_outs, ds, d_outs = self._forward(
            xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False
        )

        # truncate targets to a multiple of the reduction factor
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss (postnet outputs are scored only when a postnet exists)
        if self.postnet is None:
            l1_loss, duration_loss = self.criterion(
                None, before_outs, d_outs, ys, ds, ilens, olens
            )
        else:
            l1_loss, duration_loss = self.criterion(
                after_outs, before_outs, d_outs, ys, ds, ilens, olens
            )
        loss = l1_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"duration_loss": duration_loss.item()},
            {"loss": loss.item()},
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
                {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
            ]
        self.reporter.report(report_keys)

        return loss

    def calculate_all_attentions(
        self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs
    ):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            dict: Dict of attention weights and outputs.

        """
        with torch.no_grad():
            # drop padding beyond the longest sequence in this batch
            xs = xs[:, : max(ilens)]
            ys = ys[:, : max(olens)]
            if extras is not None:
                extras = extras[:, : max(ilens)].squeeze(-1)

            # forward propagation; keep only after_outs
            outs = self._forward(
                xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False
            )[1]

        # collect the `attn` tensors of every MultiHeadedAttention module
        # (presumably the weights of its most recent forward call), trimmed
        # to each utterance's length
        att_ws_dict = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                attn = m.attn.cpu().numpy()
                if "encoder" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                elif "decoder" in name:
                    if "src" in name:
                        attn = [
                            a[:, :ol, :il]
                            for a, il, ol in zip(attn, ilens.tolist(), olens.tolist())
                        ]
                    elif "self" in name:
                        attn = [a[:, :l, :l] for a, l in zip(attn, olens.tolist())]
                    else:
                        logging.warning("unknown attention module: " + name)
                else:
                    logging.warning("unknown attention module: " + name)
                att_ws_dict[name] = attn
        att_ws_dict["predicted_fbank"] = [
            m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())
        ]

        return att_ws_dict

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Options; only ``fastspeech_alpha``
                (default 1.0) is read.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)
        if spemb is not None:
            spembs = spemb.unsqueeze(0)
        else:
            spembs = None

        # get option
        alpha = getattr(inference_args, "fastspeech_alpha", 1.0)

        # inference
        _, outs, _ = self._forward(
            xs,
            ilens,
            spembs=spembs,
            is_inference=True,
            alpha=alpha,
        )  # (1, L, odim)

        return outs[0], None, None

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)

        Raises:
            NotImplementedError: If the integration type is neither "add"
                nor "concat".

        """
        if self.spk_embed_integration_type == "add":
            # normalize, project to adim, and add to every time step
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # normalize, broadcast over time, concat, then project back to adim
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention of shape (B, 1, Tmax).
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        # place the mask on the same device as the model parameters
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _load_teacher_model(self, model_path):
        """Load the teacher model and freeze its parameters.

        Args:
            model_path (str): Path to the teacher model; its config is read
                with ``get_model_conf``.

        Returns:
            torch.nn.Module: Teacher model with ``requires_grad=False`` on all
                parameters.

        """
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # the teacher must match this model's dimensions and reduction factor
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        from espnet.utils.dynamic_import import dynamic_import

        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model

    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        """Initialize model parameters.

        Args:
            init_type (str): Initialization method passed to ``initialize``.
            init_enc_alpha (float): Initial alpha of the encoder's
                ScaledPositionalEncoding.
            init_dec_alpha (float): Initial alpha of the decoder's
                ScaledPositionalEncoding.

        """
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding (last embed module)
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _transfer_from_teacher(self, transferred_encoder_module):
        """Copy encoder parameters from the teacher model.

        Args:
            transferred_encoder_module (str): "all" copies every encoder
                parameter pairwise (names and shapes must match; pairing stops
                at the shorter parameter list), "embed" copies only the weight
                of ``encoder.embed[0]``.

        Raises:
            NotImplementedError: For any other value.

        """
        if transferred_encoder_module == "all":
            for (n1, p1), (n2, p2) in zip(
                self.encoder.named_parameters(), self.teacher.encoder.named_parameters()
            ):
                assert n1 == n2, "It seems that encoder structure is different."
                assert p1.shape == p2.shape, "It seems that encoder size is different."
                p1.data.copy_(p2.data)
        elif transferred_encoder_module == "embed":
            student_shape = self.encoder.embed[0].weight.data.shape
            teacher_shape = self.teacher.encoder.embed[0].weight.data.shape
            assert (
                student_shape == teacher_shape
            ), "It seems that embed dimension is different."
            self.encoder.embed[0].weight.data.copy_(
                self.teacher.encoder.embed[0].weight.data
            )
        else:
            raise NotImplementedError("Support only all or embed.")

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        # imported lazily; the class lives in the Transformer TTS module
        from espnet.nets.pytorch_backend.e2e_tts_transformer import TTSPlot

        return TTSPlot

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`,
        the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "duration_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]

        return plot_keys
|
|