|
import torch |
|
from torch import nn |
|
from tasks.tts.ps_adv import PortaSpeechAdvTask, FastSpeechTask |
|
from text_to_speech.utils.commons.hparams import hparams |
|
from text_to_speech.utils.nn.seq_utils import group_hidden_by_segs |
|
|
|
|
|
class PortaSpeechAdvMLMTask(PortaSpeechAdvTask):
    """PortaSpeech adversarial TTS task with auxiliary BERT objectives.

    On top of the generator/discriminator losses, the generator step adds:
      * a contrastive loss weighted by ``hparams['lambda_cl']`` (variant chosen
        by ``hparams['cl_version']``), and
      * a masked-language-model loss weighted by ``hparams['lambda_mlm']``,
    both computed with the BERT held by ``self.model.encoder``.
    """

    def build_scheduler(self, optimizer):
        """Return ``[generator_scheduler, discriminator_scheduler]``.

        ``optimizer[0]`` uses the FastSpeech schedule; ``optimizer[1]`` uses a
        ``StepLR`` configured by ``hparams["discriminator_scheduler_params"]``.
        """
        return [
            FastSpeechTask.build_scheduler(self, optimizer[0]),
            torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1],
                                            **hparams["discriminator_scheduler_params"]),
        ]

    def on_before_optimization(self, opt_idx):
        """Clip the gradients of the parameter groups about to be stepped.

        Optimizer indices 0 and 2 clip the generator-side groups (``dp_params``,
        plus ``bert_params`` when ``self.use_bert``, plus the remaining generator
        params); every other index clips ``disc_params``.
        """
        max_norm = hparams['clip_grad_norm']
        if opt_idx in [0, 2]:
            nn.utils.clip_grad_norm_(self.dp_params, max_norm)
            if self.use_bert:
                nn.utils.clip_grad_norm_(self.bert_params, max_norm)
                nn.utils.clip_grad_norm_(self.gen_params_except_bert_and_dp, max_norm)
            else:
                nn.utils.clip_grad_norm_(self.gen_params_except_dp, max_norm)
        else:
            nn.utils.clip_grad_norm_(self.disc_params, max_norm)

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Move both LR schedulers to the current accumulated-update count."""
        if self.scheduler is not None:
            num_updates = self.global_step // hparams['accumulate_grad_batches']
            self.scheduler[0].step(num_updates)
            # NOTE(review): recent PyTorch versions warn that passing an explicit
            # epoch to StepLR.step() is deprecated.
            self.scheduler[1].step(num_updates)

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the loss for one optimizer step.

        Args:
            sample: collated batch; this method reads ``txt_tokens`` and ``mels``,
                plus whatever ``run_model`` / ``run_contrastive_learning`` read.
            batch_idx: unused here.
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.

        Returns:
            ``(total_loss, loss_output)``, or ``None`` when the discriminator step
            is skipped (discriminator not started yet, or not on its
            ``disc_interval``).
        """
        loss_output = {}
        loss_weights = {}
        disc_start = self.global_step >= hparams["disc_start_steps"] and hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            # Generator. It always trains; only the adversarial term waits for
            # `disc_start` (otherwise the generator would never train when
            # lambda_mel_adv == 0).
            loss_output, model_out = self.run_model(sample, infer=False)
            # Detached outputs, reused as the "fake" input of the discriminator step.
            self.model_out_gt = self.model_out = \
                {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                # Push the discriminator's outputs on generated mels toward 1.
                if p_ is not None:
                    loss_output['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    loss_output['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']

            aux_losses, aux_outputs = self.run_contrastive_learning(sample)
            loss_output.update(aux_losses)
            model_out.update(aux_outputs)

        elif optimizer_idx == 1:
            # Discriminator. Skip the whole step until it is enabled and on its interval.
            if not (disc_start and self.global_step % hparams['disc_interval'] == 0):
                return None
            model_out = self.model_out_gt
            mel_g = sample['mels']
            # NOTE(review): unlike the generator branch, out2mel is not applied to
            # the fake mel here; confirm this is intended for models with out2mel.
            mel_p = model_out['mel_out']
            o = self.mel_disc(mel_g)
            p, pc = o['y'], o['y_c']
            o_ = self.mel_disc(mel_p)
            p_, pc_ = o_['y'], o_['y_c']
            # Real mels -> target 1, generated mels -> target 0.
            if p_ is not None:
                loss_output["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                loss_output["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
            if pc_ is not None:
                loss_output["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                loss_output["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))

        # Only grad-carrying tensors contribute; detached logging copies are skipped.
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in loss_output.items()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_contrastive_learning(self, sample):
        """Compute the auxiliary BERT losses for the generator step.

        Adds ``cl``/``cl_v`` when ``hparams['lambda_cl'] > 0`` and ``mlm``/``mlm_v``
        when ``hparams['lambda_mlm'] > 0``. ``cl`` and ``mlm`` are weighted by their
        lambdas. ``cl_v`` and ``mlm_v`` are detached, unweighted copies, so they do
        not require grad and are excluded from ``total_loss``.

        Returns:
            ``(losses, outputs)``; ``outputs`` is currently always empty.
        """
        losses = {}
        outputs = {}
        encoder = self.model.encoder
        if hparams['lambda_cl'] > 0:
            cl_loss = self._contrastive_loss(sample, encoder)
            losses['cl_v'] = cl_loss.detach()
            losses['cl'] = cl_loss * hparams['lambda_cl']
        if hparams['lambda_mlm'] > 0:
            mlm_loss = self._mlm_loss(sample, encoder)
            losses['mlm'] = mlm_loss * hparams['lambda_mlm']
            losses['mlm_v'] = mlm_loss.detach()
        return losses, outputs

    def _contrastive_loss(self, sample, encoder):
        """Return the unweighted contrastive loss for ``hparams['cl_version']``.

        v1 / v4: pooled BERT embeddings of the two views in ``cl_feats``.
        v2:      masked mean of two forward passes of the phoneme encoder.
        v3:      token-level contrast between the two views of each sentence.
        v5:      pooled BERT embeddings of ``sent0`` vs ``sent1``, with ``hard_neg``
                 as extra negatives.

        Raises:
            NotImplementedError: for any other ``cl_version``.
        """
        cl_version = hparams.get("cl_version", "v1")
        sim = encoder.sim
        if cl_version in ("v1", "v4"):
            input_ids, attention_mask, token_type_ids, bs, _ = self._flatten_cl_pairs(sample['cl_feats'])
            bert_out = encoder.bert.bert(input_ids, attention_mask=attention_mask,
                                         token_type_ids=token_type_ids)
            pooled = encoder.pooler(attention_mask, bert_out).reshape([bs, 2, -1])
            return self._contrastive_ce(sim, pooled[:, 0], pooled[:, 1])
        if cl_version == "v2":
            txt_tokens = sample['txt_tokens']
            src_nonpadding = (txt_tokens > 0).float()[:, :, None]
            views = []
            # Two independent passes over the same input; presumably they differ
            # only through dropout — confirm.
            for _ in range(2):
                enc_out = encoder(txt_tokens, bert_feats=sample['bert_feats'], ph2word=sample['ph2word'])
                views.append((enc_out * src_nonpadding).sum(1) / src_nonpadding.sum(1))
            return self._contrastive_ce(sim, views[0], views[1])
        if cl_version == "v3":
            return self._token_level_cl_loss(sample['cl_feats'], encoder)
        if cl_version == "v5":
            cl_feats = sample['cl_feats']
            z1, z2, z3 = [self._pooled_bert(encoder, cl_feats[key])
                          for key in ('sent0', 'sent1', 'hard_neg')]
            return self._contrastive_ce(sim, z1, z2, hard_neg=z3)
        raise NotImplementedError()

    def _token_level_cl_loss(self, cl_feats, encoder):
        """Token-level contrastive loss (``cl_version == "v3"``).

        For each sample, valid token ``j`` of view 0 is the positive for valid
        token ``j`` of view 1. Per-sample losses are averaged, weighted by token
        count.

        Assumes both views of a sample have the same number of valid tokens
        (e.g. identical inputs differing only by dropout) — TODO confirm.
        """
        input_ids, attention_mask, token_type_ids, bs, t = self._flatten_cl_pairs(cl_feats)
        bert_out = encoder.bert.bert(input_ids, attention_mask=attention_mask,
                                     token_type_ids=token_type_ids)
        # Keep the (sample, view) axes so z1/z2 come from different views of the
        # same sample; the hidden size is read from the output, not hard-coded.
        hidden = bert_out.last_hidden_state.reshape([bs, 2, t, -1])
        token_mask = attention_mask.reshape([bs, 2, t]).bool()
        lengths = token_mask[:, 0].sum(-1)
        loss_sum = 0
        for i in range(bs):
            z1 = hidden[i, 0][token_mask[i, 0]]
            z2 = hidden[i, 1][token_mask[i, 1]]
            loss_sum = loss_sum + self._contrastive_ce(encoder.sim, z1, z2) * lengths[i]
        return loss_sum / lengths.sum()

    @staticmethod
    def _mlm_loss(sample, encoder):
        """Masked-LM loss of ``encoder.bert`` on ``sample['cl_feats']['mlm_*']``.

        Positions with a negative label are excluded from the mean. When no
        position is labelled, returns a zero that is still attached to the graph,
        instead of the NaN an empty mean would produce.
        """
        cl_feats = sample['cl_feats']
        mlm_input_ids = cl_feats['mlm_input_ids']
        mlm_input_ids = mlm_input_ids.reshape(-1, mlm_input_ids.size(-1))
        mlm_labels = cl_feats['mlm_labels'].reshape(-1)
        logits = encoder.bert(mlm_input_ids, cl_feats['mlm_attention_mask']).logits
        # Use the model's own output width: tokenizer.vocab_size excludes tokens
        # added to the tokenizer and may not match the logits.
        per_token = nn.CrossEntropyLoss(reduction="none")(
            logits.reshape(-1, logits.size(-1)), mlm_labels)
        # NOTE(review): CrossEntropyLoss only tolerates the negative label -100
        # (its default ignore_index); presumably the collator uses it — confirm.
        labelled = mlm_labels >= 0
        if not labelled.any():
            return per_token.sum() * 0.0
        return per_token[labelled].mean()

    @staticmethod
    def _flatten_cl_pairs(cl_feats):
        """Reshape the paired ``cl_*`` BERT inputs from ``(bs, 2, t)`` to ``(bs * 2, t)``.

        Returns ``(input_ids, attention_mask, token_type_ids, bs, t)``. After the
        reshape, rows ``2i`` and ``2i + 1`` hold entries ``[i, 0]`` and ``[i, 1]``.
        """
        bs, _, t = cl_feats['cl_input_ids'].shape
        input_ids, attention_mask, token_type_ids = [
            cl_feats[key].reshape([bs * 2, t])
            for key in ('cl_input_ids', 'cl_attention_mask', 'cl_token_type_ids')
        ]
        return input_ids, attention_mask, token_type_ids, bs, t

    @staticmethod
    def _pooled_bert(encoder, feats):
        """Run ``encoder.bert.bert`` on one ``cl_*`` feature dict and pool with ``encoder.pooler``."""
        attention_mask = feats['cl_attention_mask']
        bert_out = encoder.bert.bert(feats['cl_input_ids'], attention_mask=attention_mask,
                                     token_type_ids=feats['cl_token_type_ids'])
        return encoder.pooler(attention_mask, bert_out)

    @staticmethod
    def _contrastive_ce(sim, z1, z2, hard_neg=None):
        """Cross-entropy over ``sim`` scores whose positives lie on the diagonal.

        Row ``i`` of ``z1`` is scored against every row of ``z2`` (and of
        ``hard_neg``, appended as extra columns, when given); its target is
        column ``i``.
        """
        logits = sim(z1.unsqueeze(1), z2.unsqueeze(0))
        if hard_neg is not None:
            logits = torch.cat([logits, sim(z1.unsqueeze(1), hard_neg.unsqueeze(0))], 1)
        labels = torch.arange(logits.size(0), device=logits.device)
        return nn.CrossEntropyLoss()(logits, labels)
|
|