File size: 832 Bytes
14d1720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn


class FS2Loss(nn.Module):
    """Masked mel-spectrogram and duration losses.

    Three terms are computed over the positions selected by the masks:

    * MSE between the pre-postnet mel prediction and the target,
    * MSE between the post-postnet mel prediction and the target,
    * MAE (L1) between the predicted and target durations.

    Each term is scaled by its own weight before being returned. The
    defaults reproduce the historical fixed weighting (0.1, 1.0, 0.01).

    Args:
        mel_weight: Scale applied to the pre-postnet mel MSE.
        postnet_weight: Scale applied to the post-postnet mel MSE.
        duration_weight: Scale applied to the duration MAE.
    """

    def __init__(self, mel_weight: float = 0.1, postnet_weight: float = 1.0,
                 duration_weight: float = 0.01):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()
        self.mel_weight = mel_weight
        self.postnet_weight = postnet_weight
        self.duration_weight = duration_weight

    def forward(self, d_pred, d_truth, mel_pred, mel_postnet, mel_truth, src_mask, mel_mask):
        """Return ``(mel_loss, mel_postnet_loss, d_loss)`` as scalar tensors.

        Args:
            d_pred: Predicted durations; must be broadcastable with
                ``src_mask``.
            d_truth: Target durations, same layout as ``d_pred``. Integer
                targets are cast to the dtype of ``d_pred``.
            mel_pred: Mel prediction before the postnet.
            mel_postnet: Mel prediction after the postnet.
            mel_truth: Target mel-spectrogram.
            src_mask: Mask over duration positions. Nonzero/True entries are
                the ones that are KEPT (this is a "valid" mask, not a
                padding mask).
            mel_mask: Mask over mel frames with the same keep-semantics. It
                gets a trailing singleton axis so it broadcasts over the last
                axis of the mel tensors.
                # NOTE(review): presumably mel tensors are
                # (batch, frames, n_mels) and masks are (batch, length) —
                # confirm against the caller.

        Returns:
            Tuple of three weighted scalar losses, in the order
            mel, mel-postnet, duration.
        """
        # masked_select requires boolean masks on current PyTorch; coercing
        # here keeps 0/1 byte masks from older call sites working.
        src_keep = src_mask.bool()
        mel_keep = mel_mask.bool().unsqueeze(-1)

        d_pred = d_pred.masked_select(src_keep)
        d_truth = d_truth.masked_select(src_keep)
        if not d_truth.is_floating_point():
            # Durations are often stored as integer frame counts; align the
            # dtype so the L1 loss does not hit a float/long mismatch.
            d_truth = d_truth.to(d_pred.dtype)

        mel_pred = mel_pred.masked_select(mel_keep)
        mel_postnet = mel_postnet.masked_select(mel_keep)
        mel_truth = mel_truth.masked_select(mel_keep)

        mel_loss = self.mse_loss(mel_pred, mel_truth) * self.mel_weight
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_truth) * self.postnet_weight
        d_loss = self.mae_loss(d_pred, d_truth) * self.duration_weight

        return mel_loss, mel_postnet_loss, d_loss