|
import torch |
|
from modules.GenerSpeech.model.glow_modules import Glow |
|
from modules.fastspeech.tts_modules import PitchPredictor |
|
import random |
|
from modules.GenerSpeech.model.prosody_util import ProsodyAligner, LocalStyleAdaptor |
|
from utils.pitch_utils import f0_to_coarse, denorm_f0 |
|
from modules.commons.common_layers import * |
|
import torch.distributions as dist |
|
from utils.hparams import hparams |
|
from modules.GenerSpeech.model.mixstyle import MixStyle |
|
from modules.fastspeech.fs2 import FastSpeech2 |
|
import json |
|
from modules.fastspeech.tts_modules import DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS |
|
|
|
class GenerSpeech(FastSpeech2): |
|
''' |
|
GenerSpeech: Towards Style Transfer for Generalizable Out-Of-Domain Text-to-Speech |
|
https://arxiv.org/abs/2205.07211 |
|
''' |
|
def __init__(self, dictionary, out_dims=None): |
|
super().__init__(dictionary, out_dims) |
|
|
|
|
|
self.norm = MixStyle(p=0.5, alpha=0.1, eps=1e-6, hidden_size=self.hidden_size) |
|
|
|
|
|
self.emo_embed_proj = Linear(256, self.hidden_size, bias=True) |
|
|
|
|
|
|
|
self.prosody_extractor_utter = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx) |
|
self.l1_utter = nn.Linear(self.hidden_size * 2, self.hidden_size) |
|
self.align_utter = ProsodyAligner(num_layers=2) |
|
|
|
|
|
self.prosody_extractor_ph = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx) |
|
self.l1_ph = nn.Linear(self.hidden_size * 2, self.hidden_size) |
|
self.align_ph = ProsodyAligner(num_layers=2) |
|
|
|
|
|
self.prosody_extractor_word = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx) |
|
self.l1_word = nn.Linear(self.hidden_size * 2, self.hidden_size) |
|
self.align_word = ProsodyAligner(num_layers=2) |
|
|
|
self.pitch_inpainter_predictor = PitchPredictor( |
|
self.hidden_size, n_chans=self.hidden_size, |
|
n_layers=3, dropout_rate=0.1, odim=2, |
|
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) |
|
|
|
|
|
self.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS |
|
self.embed_positions = SinusoidalPositionalEmbedding( |
|
self.hidden_size, self.padding_idx, |
|
init_size=self.max_source_positions + self.padding_idx + 1, |
|
) |
|
|
|
|
|
cond_hs = 80 |
|
if hparams.get('use_txt_cond', True): |
|
cond_hs = cond_hs + hparams['hidden_size'] |
|
|
|
cond_hs = cond_hs + hparams['hidden_size'] * 3 |
|
self.post_flow = Glow( |
|
80, hparams['post_glow_hidden'], hparams['post_glow_kernel_size'], 1, |
|
hparams['post_glow_n_blocks'], hparams['post_glow_n_block_layers'], |
|
n_split=4, n_sqz=2, |
|
gin_channels=cond_hs, |
|
share_cond_layers=hparams['post_share_cond_layers'], |
|
share_wn_layers=hparams['share_wn_layers'], |
|
sigmoid_scale=hparams['sigmoid_scale'] |
|
) |
|
self.prior_dist = dist.Normal(0, 1) |
|
|
|
|
|
def forward(self, txt_tokens, mel2ph=None, ref_mel2ph=None, ref_mel2word=None, spk_embed=None, emo_embed=None, ref_mels=None, |
|
f0=None, uv=None, skip_decoder=False, global_steps=0, infer=False, **kwargs): |
|
ret = {} |
|
encoder_out = self.encoder(txt_tokens) |
|
src_nonpadding = (txt_tokens > 0).float()[:, :, None] |
|
|
|
|
|
spk_embed = self.spk_embed_proj(spk_embed)[:, None, :] |
|
emo_embed = self.emo_embed_proj(emo_embed)[:, None, :] |
|
|
|
|
|
|
|
dur_inp = (encoder_out + spk_embed + emo_embed) * src_nonpadding |
|
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret) |
|
tgt_nonpadding = (mel2ph > 0).float()[:, :, None] |
|
decoder_inp = self.expand_states(encoder_out, mel2ph) |
|
decoder_inp = self.norm(decoder_inp, spk_embed + emo_embed) |
|
|
|
|
|
ret['ref_mel2ph'] = ref_mel2ph |
|
ret['ref_mel2word'] = ref_mel2word |
|
prosody_utter_mel = self.get_prosody_utter(decoder_inp, ref_mels, ret, infer, global_steps) |
|
prosody_ph_mel = self.get_prosody_ph(decoder_inp, ref_mels, ret, infer, global_steps) |
|
prosody_word_mel = self.get_prosody_word(decoder_inp, ref_mels, ret, infer, global_steps) |
|
|
|
|
|
pitch_inp_domain_agnostic = decoder_inp * tgt_nonpadding |
|
pitch_inp_domain_specific = (decoder_inp + spk_embed + emo_embed + prosody_utter_mel + prosody_ph_mel + prosody_word_mel) * tgt_nonpadding |
|
predicted_pitch = self.inpaint_pitch(pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret) |
|
|
|
|
|
decoder_inp = decoder_inp + spk_embed + emo_embed + predicted_pitch + prosody_utter_mel + prosody_ph_mel + prosody_word_mel |
|
ret['decoder_inp'] = decoder_inp = decoder_inp * tgt_nonpadding |
|
if skip_decoder: |
|
return ret |
|
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) |
|
|
|
|
|
is_training = self.training |
|
ret['x_mask'] = tgt_nonpadding |
|
ret['spk_embed'] = spk_embed |
|
ret['emo_embed'] = emo_embed |
|
ret['ref_prosody'] = prosody_utter_mel + prosody_ph_mel + prosody_word_mel |
|
self.run_post_glow(ref_mels, infer, is_training, ret) |
|
return ret |
|
|
|
def get_prosody_ph(self, encoder_out, ref_mels, ret, infer=False, global_steps=0): |
|
|
|
if global_steps > hparams['vq_start'] or infer: |
|
prosody_embedding, loss, ppl = self.prosody_extractor_ph(ref_mels, ret['ref_mel2ph'], no_vq=False) |
|
ret['vq_loss_ph'] = loss |
|
ret['ppl_ph'] = ppl |
|
else: |
|
prosody_embedding = self.prosody_extractor_ph(ref_mels, ret['ref_mel2ph'], no_vq=True) |
|
|
|
|
|
positions = self.embed_positions(prosody_embedding[:, :, 0]) |
|
prosody_embedding = self.l1_ph(torch.cat([prosody_embedding, positions], dim=-1)) |
|
|
|
|
|
|
|
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data |
|
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data |
|
if global_steps < hparams['forcing']: |
|
output, guided_loss, attn_emo = self.align_ph(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=True) |
|
else: |
|
output, guided_loss, attn_emo = self.align_ph(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=False) |
|
|
|
ret['gloss_ph'] = guided_loss |
|
ret['attn_ph'] = attn_emo |
|
return output.transpose(0, 1) |
|
|
|
def get_prosody_word(self, encoder_out, ref_mels, ret, infer=False, global_steps=0): |
|
|
|
if global_steps > hparams['vq_start'] or infer: |
|
prosody_embedding, loss, ppl = self.prosody_extractor_word(ref_mels, ret['ref_mel2word'], no_vq=False) |
|
ret['vq_loss_word'] = loss |
|
ret['ppl_word'] = ppl |
|
else: |
|
prosody_embedding = self.prosody_extractor_word(ref_mels, ret['ref_mel2word'], no_vq=True) |
|
|
|
|
|
positions = self.embed_positions(prosody_embedding[:, :, 0]) |
|
prosody_embedding = self.l1_word(torch.cat([prosody_embedding, positions], dim=-1)) |
|
|
|
|
|
|
|
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data |
|
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data |
|
if global_steps < hparams['forcing']: |
|
output, guided_loss, attn_emo = self.align_word(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=True) |
|
else: |
|
output, guided_loss, attn_emo = self.align_word(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=False) |
|
ret['gloss_word'] = guided_loss |
|
ret['attn_word'] = attn_emo |
|
return output.transpose(0, 1) |
|
|
|
def get_prosody_utter(self, encoder_out, ref_mels, ret, infer=False, global_steps=0): |
|
|
|
if global_steps > hparams['vq_start'] or infer: |
|
prosody_embedding, loss, ppl = self.prosody_extractor_utter(ref_mels, no_vq=False) |
|
ret['vq_loss_utter'] = loss |
|
ret['ppl_utter'] = ppl |
|
else: |
|
prosody_embedding = self.prosody_extractor_utter(ref_mels, no_vq=True) |
|
|
|
|
|
positions = self.embed_positions(prosody_embedding[:, :, 0]) |
|
prosody_embedding = self.l1_utter(torch.cat([prosody_embedding, positions], dim=-1)) |
|
|
|
|
|
|
|
src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data |
|
prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data |
|
if global_steps < hparams['forcing']: |
|
output, guided_loss, attn_emo = self.align_utter(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=True) |
|
else: |
|
output, guided_loss, attn_emo = self.align_utter(encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1), |
|
src_key_padding_mask, prosody_key_padding_mask, forcing=False) |
|
ret['gloss_utter'] = guided_loss |
|
ret['attn_utter'] = attn_emo |
|
return output.transpose(0, 1) |
|
|
|
|
|
|
|
def inpaint_pitch(self, pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret): |
|
if hparams['pitch_type'] == 'frame': |
|
pitch_padding = mel2ph == 0 |
|
if hparams['predictor_grad'] != 1: |
|
pitch_inp_domain_agnostic = pitch_inp_domain_agnostic.detach() + hparams['predictor_grad'] * (pitch_inp_domain_agnostic - pitch_inp_domain_agnostic.detach()) |
|
pitch_inp_domain_specific = pitch_inp_domain_specific.detach() + hparams['predictor_grad'] * (pitch_inp_domain_specific - pitch_inp_domain_specific.detach()) |
|
|
|
pitch_domain_agnostic = self.pitch_predictor(pitch_inp_domain_agnostic) |
|
pitch_domain_specific = self.pitch_inpainter_predictor(pitch_inp_domain_specific) |
|
pitch_pred = pitch_domain_agnostic + pitch_domain_specific |
|
ret['pitch_pred'] = pitch_pred |
|
|
|
use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv'] |
|
if f0 is None: |
|
f0 = pitch_pred[:, :, 0] |
|
if use_uv: |
|
uv = pitch_pred[:, :, 1] > 0 |
|
f0_denorm = denorm_f0(f0, uv if use_uv else None, hparams, pitch_padding=pitch_padding) |
|
pitch = f0_to_coarse(f0_denorm) |
|
ret['f0_denorm'] = f0_denorm |
|
ret['f0_denorm_pred'] = denorm_f0(pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None, hparams, pitch_padding=pitch_padding) |
|
if hparams['pitch_type'] == 'ph': |
|
pitch = torch.gather(F.pad(pitch, [1, 0]), 1, mel2ph) |
|
ret['f0_denorm'] = torch.gather(F.pad(ret['f0_denorm'], [1, 0]), 1, mel2ph) |
|
ret['f0_denorm_pred'] = torch.gather(F.pad(ret['f0_denorm_pred'], [1, 0]), 1, mel2ph) |
|
pitch_embed = self.pitch_embed(pitch) |
|
return pitch_embed |
|
|
|
def run_post_glow(self, tgt_mels, infer, is_training, ret): |
|
x_recon = ret['mel_out'].transpose(1, 2) |
|
g = x_recon |
|
B, _, T = g.shape |
|
if hparams.get('use_txt_cond', True): |
|
g = torch.cat([g, ret['decoder_inp'].transpose(1, 2)], 1) |
|
g_spk_embed = ret['spk_embed'].repeat(1, T, 1).transpose(1, 2) |
|
g_emo_embed = ret['emo_embed'].repeat(1, T, 1).transpose(1, 2) |
|
l_ref_prosody = ret['ref_prosody'].transpose(1, 2) |
|
g = torch.cat([g, g_spk_embed, g_emo_embed, l_ref_prosody], dim=1) |
|
prior_dist = self.prior_dist |
|
if not infer: |
|
if is_training: |
|
self.train() |
|
x_mask = ret['x_mask'].transpose(1, 2) |
|
y_lengths = x_mask.sum(-1) |
|
g = g.detach() |
|
tgt_mels = tgt_mels.transpose(1, 2) |
|
z_postflow, ldj = self.post_flow(tgt_mels, x_mask, g=g) |
|
ldj = ldj / y_lengths / 80 |
|
ret['z_pf'], ret['ldj_pf'] = z_postflow, ldj |
|
ret['postflow'] = -prior_dist.log_prob(z_postflow).mean() - ldj.mean() |
|
else: |
|
x_mask = torch.ones_like(x_recon[:, :1, :]) |
|
z_post = prior_dist.sample(x_recon.shape).to(g.device) * hparams['noise_scale'] |
|
x_recon_, _ = self.post_flow(z_post, x_mask, g, reverse=True) |
|
x_recon = x_recon_ |
|
ret['mel_out'] = x_recon.transpose(1, 2) |