|
import torch |
|
from modules.portaspeech.portaspeech_flow import PortaSpeechFlow |
|
from tasks.tts.fs2 import FastSpeech2Task |
|
from tasks.tts.ps import PortaSpeechTask |
|
from utils.pitch_utils import denorm_f0 |
|
from utils.hparams import hparams |
|
|
|
|
|
class PortaSpeechFlowTask(PortaSpeechTask):
    """PortaSpeech training task with an optional post-flow ("post-glow") stage.

    Differences from ``PortaSpeechTask`` visible in this class:

    * the model is a ``PortaSpeechFlow``;
    * when ``hparams['two_stage']`` and ``hparams['use_post_flow']`` are both
      set, ``build_optimizer`` returns two optimizers (index 0: all parameters
      whose name does not contain ``post_flow``; index 1:
      ``self.model.post_flow``) and ``_training_step`` lets only the one that
      matches the current stage take a step;
    * the stage switches once ``global_step`` reaches
      ``hparams['post_glow_training_start']``.
    """

    # Extra steps past ``post_glow_training_start`` before
    # ``run_model(infer=True)`` enables the post-flow.
    _POST_GLOW_INFER_DELAY = 1000

    def __init__(self):
        super().__init__()
        # Whether the current step is in the post-flow stage. Refreshed by
        # _training_step and validation_step, then read by run_model.
        self.training_post_glow = False

    def _post_glow_active(self):
        """Return whether the post-flow stage is active at ``self.global_step``.

        Truthy only when ``global_step >= hparams['post_glow_training_start']``
        and ``hparams['use_post_flow']`` is set.
        """
        return (self.global_step >= hparams['post_glow_training_start']
                and hparams['use_post_flow'])

    def build_tts_model(self):
        """Build a ``PortaSpeechFlow`` sized to the phoneme and word vocabularies."""
        ph_dict_size = len(self.token_encoder)
        word_dict_size = len(self.word_encoder)
        self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)

    def _training_step(self, sample, batch_idx, opt_idx):
        """Compute the loss for optimizer ``opt_idx`` on one batch.

        Args:
            sample: batch dict consumed by ``run_model``.
            batch_idx: index of the batch (unused here).
            opt_idx: 0 for the main optimizer, 1 for the post-flow optimizer
                (order returned by ``build_optimizer`` in two-stage mode).

        Returns:
            ``(total_loss, loss_output)``, or ``None`` to skip the step. A step
            is skipped in two-stage mode when ``opt_idx`` does not match the
            active stage, and whenever the ``postflow`` loss entry is present
            but ``None``.
        """
        self.training_post_glow = self._post_glow_active()
        if hparams['two_stage'] and (
                (opt_idx == 0 and self.training_post_glow)
                or (opt_idx == 1 and not self.training_post_glow)):
            return None
        loss_output, _ = self.run_model(sample)
        if 'postflow' in loss_output and loss_output['postflow'] is None:
            return None
        # Only tensors that still require grad are optimized. Detached entries
        # (e.g. l1/ssim during the post-flow stage) stay in loss_output only.
        total_loss = sum(v for v in loss_output.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on ``sample``.

        Args:
            sample: batch dict. The keys read depend on the mode; see the helpers.
            infer: if True, run inference and return the model output dict.
                Otherwise run a training forward pass and return
                ``(losses, output)``.
            **kwargs: ``infer_use_gt_dur`` overrides ``hparams['use_gt_dur']``
                during inference.
        """
        if infer:
            return self._run_model_infer(sample, **kwargs)
        return self._run_model_train(sample)

    def _run_model_train(self, sample):
        """Forward with ground-truth alignments and targets; return ``(losses, output)``."""
        training_post_glow = self.training_post_glow
        output = self.model(sample['txt_tokens'],
                            sample['word_tokens'],
                            ph2word=sample['ph2word'],
                            mel2word=sample['mel2word'],
                            mel2ph=sample['mel2ph'],
                            word_len=sample['word_lengths'].max(),
                            tgt_mels=sample['mels'],
                            pitch=sample.get('pitch'),
                            spk_embed=sample.get('spk_embed'),
                            spk_id=sample.get('spk_ids'),
                            infer=False,
                            forward_post_glow=training_post_glow,
                            two_stage=hparams['two_stage'],
                            global_step=self.global_step,
                            bert_feats=sample.get('bert_feats'))
        losses = {}
        self.add_mel_loss(output['mel_out'], sample['mels'], losses)
        if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
            losses['postflow'] = output['postflow']
            # Keep the reconstruction terms in the dict but detach them, so the
            # grad-requiring sum in _training_step excludes them.
            losses['l1'] = losses['l1'].detach()
            losses['ssim'] = losses['ssim'].detach()
        if not training_post_glow or not hparams['two_stage'] or not self.training:
            losses['kl'] = output['kl']
            if self.global_step < hparams['kl_start_steps']:
                # Before kl_start_steps the KL term carries no gradient.
                losses['kl'] = losses['kl'].detach()
            else:
                # Floor the KL at kl_min.
                losses['kl'] = torch.clamp(losses['kl'], min=hparams['kl_min'])
            losses['kl'] = losses['kl'] * hparams['lambda_kl']
            if hparams['dur_level'] == 'word':
                self.add_dur_loss(
                    output['dur'], sample['mel2word'], sample['word_lengths'], sample['txt_tokens'], losses)
                self.get_attn_stats(output['attn'], sample, losses)
            else:
                # This class does not override add_dur_loss, so a plain super()
                # call would reach the same word-level method used just above
                # (5 positional args). The phoneme-level (mel2ph) variant with
                # 4 args lives above PortaSpeechTask in the MRO.
                super(PortaSpeechTask, self).add_dur_loss(
                    output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)
        return losses, output

    def _run_model_infer(self, sample, **kwargs):
        """Run inference, optionally with ground-truth durations; return the model output."""
        use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
        forward_post_glow = (
            self.global_step >= hparams['post_glow_training_start'] + self._POST_GLOW_INFER_DELAY
            and hparams['use_post_flow'])
        output = self.model(
            sample['txt_tokens'],
            sample['word_tokens'],
            ph2word=sample['ph2word'],
            word_len=sample['word_lengths'].max(),
            pitch=sample.get('pitch'),
            mel2ph=sample['mel2ph'] if use_gt_dur else None,
            # Honour the infer_use_gt_dur override here too, not only for mel2ph.
            mel2word=sample['mel2word'] if hparams['profile_infer'] or use_gt_dur else None,
            infer=True,
            forward_post_glow=forward_post_glow,
            spk_embed=sample.get('spk_embed'),
            spk_id=sample.get('spk_ids'),
            two_stage=hparams['two_stage'],
            bert_feats=sample.get('bert_feats'))
        return output

    def validation_step(self, sample, batch_idx):
        """Refresh the post-flow stage flag, then delegate to the parent validation step."""
        self.training_post_glow = self._post_glow_active()
        return super().validation_step(sample, batch_idx)

    def save_valid_result(self, sample, batch_idx, model_out):
        """Log the parent's validation results, plus the ``mel_out_fvae`` output.

        The ``mel_out_fvae`` output is logged as audio and as a mel plot, but
        only when ``global_step > 0`` and ``hparams['use_post_flow']`` is set.
        ``mel_out_fvae`` is presumably the FVAE decoder output before post-flow
        refinement (TODO confirm against PortaSpeechFlow).
        """
        super().save_valid_result(sample, batch_idx, model_out)
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        if self.global_step > 0 and hparams['use_post_flow']:
            mel_fvae = model_out['mel_out_fvae'][0]
            wav_pred = self.vocoder.spec2wav(mel_fvae.cpu(), f0=f0_gt)
            self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'], mel_fvae,
                          f'mel_fvae_{batch_idx}', f0s=f0_gt)

    def build_optimizer(self, model):
        """Create AdamW optimizer(s) over ``self.model`` (the ``model`` argument is unused).

        Returns:
            ``[main, post_flow]`` when ``hparams['two_stage']`` and
            ``hparams['use_post_flow']`` are both set. ``main`` covers every
            parameter whose name lacks ``post_flow``; ``post_flow`` covers
            ``self.model.post_flow`` and uses ``hparams['post_flow_lr']``.
            Otherwise returns ``[main]`` over all parameters.
        """
        def make_adamw(params, lr):
            # Shared AdamW settings; only the parameter set and lr differ.
            return torch.optim.AdamW(
                params,
                lr=lr,
                betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
                weight_decay=hparams['weight_decay'])

        if hparams['two_stage'] and hparams['use_post_flow']:
            self.optimizer = make_adamw(
                [p for name, p in self.model.named_parameters() if 'post_flow' not in name],
                hparams['lr'])
            self.post_flow_optimizer = make_adamw(
                self.model.post_flow.parameters(), hparams['post_flow_lr'])
            return [self.optimizer, self.post_flow_optimizer]
        self.optimizer = make_adamw(self.model.parameters(), hparams['lr'])
        return [self.optimizer]

    def build_scheduler(self, optimizer):
        """Build the FastSpeech2 LR scheduler for ``optimizer[0]`` only.

        ``optimizer`` is presumably the list returned by ``build_optimizer``.
        No scheduler is created here for the post-flow optimizer.
        """
        return FastSpeech2Task.build_scheduler(self, optimizer[0])