import torch

import utils
from modules.diff.diffusion import GaussianDiffusion
from modules.diff.net import DiffNet
from tasks.tts.fs2 import FastSpeech2Task
from utils.hparams import hparams

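# Registry of denoiser backbones for the diffusion decoder, selected via
# hparams['diff_decoder_type'].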
DIFF_DECODERS = {
    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
}

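# A FastSpeech 2 task whose mel decoder is replaced by a Gaussian diffusion
# model: the FastSpeech 2 front end still handles duration, pitch, and energy,
# while the DiffNet denoiser generates the mel-spectrogram by iterative
# refinement.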
class DiffFsTask(FastSpeech2Task):
    def build_tts_model(self):
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
        )

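    # Shared forward pass for training and validation. With infer=False the
    # diffusion model returns its training objective as 'diff_loss'; with
    # infer=True it runs the reverse diffusion process and returns the sampled
    # spectrogram as 'mel_out'.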
    def run_model(self, model, sample, return_output=False, infer=False):
        txt_tokens = sample['txt_tokens']  # [B, T_txt]
        target = sample['mels']  # [B, T_mel, num_mel_bins]
        mel2ph = sample['mel2ph']  # [B, T_mel]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        if hparams['pitch_type'] == 'cwt':
            # Reconstruct the F0 contour from its continuous wavelet transform
            # representation before feeding it to the model.
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if not return_output:
            return losses
        else:
            return losses, output

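    # One training step: the total loss is the sum of all differentiable loss
    # terms; batch size and current learning rate are logged alongside them.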
    def _training_step(self, sample, batch_idx, _):
        log_outputs = self.run_model(self.model, sample)
        total_loss = sum(v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad)
        log_outputs['batch_size'] = sample['txt_tokens'].size(0)
        log_outputs['lr'] = self.scheduler.get_lr()[0]
        return total_loss, log_outputs

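    # Validation: compute the same losses as training; for the first
    # `num_valid_plots` batches, additionally run inference and plot the
    # generated mel against the ground truth.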
    def validation_step(self, sample, batch_idx):
        outputs = {}
        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            _, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

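    # Step-based LR schedule: halve the learning rate every `decay_steps`
    # optimizer steps.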
    def build_scheduler(self, optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

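    # Manual optimization hook: apply and reset gradients, then advance the
    # scheduler by the number of completed optimizer steps (global_step
    # divided by the gradient-accumulation factor).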
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        if optimizer is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
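
# A minimal usage sketch. Hedged assumptions: `set_hparams` is the config
# parser in this codebase's utils.hparams, and `start()` is assumed to be the
# base task's entry point; the actual launcher may differ by repo version.
#
#     from utils.hparams import set_hparams
#     set_hparams()        # parse --config / hparams from the command line
#     DiffFsTask.start()   # assumed BaseTask-style entry point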