# espnet/nets/pytorch_backend/fastspeech/duration_calculator.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Duration calculator related modules."""
import torch
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
from espnet.nets.pytorch_backend.nets_utils import pad_list
class DurationCalculator(torch.nn.Module):
    """Duration calculator module for FastSpeech.

    Extracts per-token durations from the encoder-decoder attentions of a
    pretrained auto-regressive teacher model (Transformer or Tacotron 2).

    Todo:
        * Fix the duplicated calculation of diagonal head decision

    """

    def __init__(self, teacher_model):
        """Initialize duration calculator module.

        Args:
            teacher_model (e2e_tts_transformer.Transformer or
                e2e_tts_tacotron2.Tacotron2): Pretrained auto-regressive model.

        Raises:
            ValueError: If teacher_model is neither a Transformer nor a Tacotron2.

        """
        super(DurationCalculator, self).__init__()
        if isinstance(teacher_model, Transformer):
            # -1 is a sentinel meaning "diagonal head not selected yet";
            # the real index is decided lazily on the first forward call.
            self.register_buffer("diag_head_idx", torch.tensor(-1))
        elif isinstance(teacher_model, Tacotron2):
            pass
        else:
            raise ValueError(
                "teacher model should be the instance of "
                "e2e_tts_transformer.Transformer or e2e_tts_tacotron2.Tacotron2."
            )
        self.teacher_model = teacher_model

    def forward(self, xs, ilens, ys, olens, spembs=None):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
            ilens (Tensor): Batch of lengths of each input sequence (B,).
            ys (Tensor):
                Batch of the padded sequence of target features (B, Lmax, odim).
            olens (Tensor): Batch of lengths of each output sequence (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            Tensor: Batch of durations (B, Tmax).

        """
        if isinstance(self.teacher_model, Transformer):
            att_ws = self._calculate_encoder_decoder_attentions(
                xs, ilens, ys, olens, spembs=spembs
            )
            # TODO(kan-bayashi): fix this issue
            # this does not work in multi-gpu case. registered buffer is not saved.
            if int(self.diag_head_idx) == -1:
                self._init_diagonal_head(att_ws)
            att_ws = att_ws[:, self.diag_head_idx]
        else:
            # NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2
            att_ws = self.teacher_model.calculate_all_attentions(
                xs, ilens, ys, spembs=spembs, keep_tensor=True
            )
        durations = [
            self._calculate_duration(att_w, ilen, olen)
            for att_w, ilen, olen in zip(att_ws, ilens, olens)
        ]

        return pad_list(durations, 0)

    @staticmethod
    def _calculate_duration(att_w, ilen, olen):
        """Count, for each input position, how many output frames attend to it.

        Args:
            att_w (Tensor): Attention weight matrix (Lmax, Tmax).
            ilen (LongTensor): Actual input length (scalar).
            olen (LongTensor): Actual output length (scalar).

        Returns:
            LongTensor: Durations (ilen,), one frame count per input position.

        """
        # For each valid output frame, take the most-attended input index, then
        # count occurrences per input position with a single O(olen) bincount
        # instead of an O(ilen * olen) per-index equality scan.
        return torch.bincount(
            att_w[:olen, :ilen].argmax(-1), minlength=int(ilen)
        )

    def _init_diagonal_head(self, att_ws):
        """Select the most diagonal attention head and store its index as a buffer.

        Args:
            att_ws (Tensor): Stacked attentions of all heads/layers
                (B, H*L, Lmax, Tmax).

        """
        # Score each head by the batch/time average of its per-frame maximum
        # attention weight; the sharpest (most monotonic) head wins.
        diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1).mean(dim=0)  # (H * L,)
        self.register_buffer("diag_head_idx", diagonal_scores.argmax())

    def _calculate_encoder_decoder_attentions(self, xs, ilens, ys, olens, spembs=None):
        """Collect and stack all source attentions from the teacher Transformer.

        Args:
            xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
            ilens (Tensor): Batch of lengths of each input sequence (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (Tensor): Batch of lengths of each output sequence (B,).
            spembs (Tensor, optional): Speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Encoder-decoder attentions of all layers stacked along the
                head dimension (B, H*L, Lmax, Tmax).

        """
        att_dict = self.teacher_model.calculate_all_attentions(
            xs, ilens, ys, olens, spembs=spembs, skip_output=True, keep_tensor=True
        )
        # Keep only encoder-decoder ("src_attn") attentions; concatenate the
        # heads of every layer along dim 1.
        return torch.cat(
            [att_dict[k] for k in att_dict.keys() if "src_attn" in k], dim=1
        )  # (B, H*L, Lmax, Tmax)