Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class FS2Loss(nn.Module): | |
def __init__(self): | |
super(FS2Loss, self).__init__() | |
self.mse_loss = nn.MSELoss() | |
self.mae_loss = nn.L1Loss() | |
def forward(self, d_pred, d_truth, mel_pred, mel_postnet, mel_truth, src_mask, mel_mask): | |
d_pred = d_pred.masked_select(src_mask) | |
d_truth = d_truth.masked_select(src_mask) | |
mel_pred = mel_pred.masked_select(mel_mask.unsqueeze(-1)) | |
mel_postnet = mel_postnet.masked_select(mel_mask.unsqueeze(-1)) | |
mel_truth = mel_truth.masked_select(mel_mask.unsqueeze(-1)) | |
mel_loss = self.mse_loss(mel_pred, mel_truth) * 0.1 | |
mel_postnet_loss = self.mse_loss(mel_postnet, mel_truth) | |
d_loss = self.mae_loss(d_pred, d_truth) * 0.01 | |
return mel_loss, mel_postnet_loss, d_loss | |