from typing import List, Tuple

import torch
from auraloss.time import ESRLoss, SDSDRLoss, SISDRLoss, SNRLoss
from torch import Tensor
from torch.nn import Module

from models.config import VocoderBasicConfig, VocoderModelConfig

from .multi_resolution_stft_loss import MultiResolutionSTFTLoss


class UnivnetLoss(Module):
    r"""UnivnetLoss is a PyTorch Module that computes the generator and discriminator losses for UnivNet."""

    def __init__(self):
        r"""Initializes the UnivnetLoss module."""
        super().__init__()

        train_config = VocoderBasicConfig()

        self.stft_lamb = train_config.stft_lamb
        self.model_config = VocoderModelConfig()

        self.stft_criterion = MultiResolutionSTFTLoss(self.model_config.mrd.resolutions)
        # Time-domain criteria from auraloss. Note: SISDRLoss and SDSDRLoss are
        # instantiated but not used in forward(); only ESR and SNR contribute
        # to the generator objective.
        self.esr_loss = ESRLoss()
        self.sisdr_loss = SISDRLoss()
        self.snr_loss = SNRLoss()
        self.sdsdr_loss = SDSDRLoss()

    def forward(
        self,
        audio: Tensor,
        fake_audio: Tensor,
        res_fake: List[Tuple[Tensor, Tensor]],
        period_fake: List[Tuple[Tensor, Tensor]],
        res_real: List[Tuple[Tensor, Tensor]],
        period_real: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Calculate the losses for the generator and discriminator.

        Args:
            audio (Tensor): The real audio samples, shaped (batch, 1, time).
            fake_audio (Tensor): The generated audio samples, shaped (batch, 1, time).
            res_fake (List[Tuple[Tensor, Tensor]]): The multi-resolution discriminator's
                (feature map, score) outputs for the fake audio.
            period_fake (List[Tuple[Tensor, Tensor]]): The multi-period discriminator's
                (feature map, score) outputs for the fake audio.
            res_real (List[Tuple[Tensor, Tensor]]): The multi-resolution discriminator's
                (feature map, score) outputs for the real audio.
            period_real (List[Tuple[Tensor, Tensor]]): The multi-period discriminator's
                (feature map, score) outputs for the real audio.

        Returns:
            tuple: The total generator loss, the discriminator loss, the STFT
                loss, the score loss, the ESR loss, and the SNR loss.
        """
        # Multi-resolution STFT loss between the generated and real waveforms,
        # weighted by stft_lamb. squeeze(1) drops the channel dimension.
        sc_loss, mag_loss = self.stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
        stft_loss = (sc_loss + mag_loss) * self.stft_lamb

        # Right-pad the fake audio to match the length of the real audio so the
        # time-domain losses compare equal-length signals. If fake_audio is
        # longer, the negative padding trims it instead.
        padding = audio.shape[2] - fake_audio.shape[2]
        fake_audio_padded = torch.nn.functional.pad(fake_audio, (0, padding))

        esr_loss = self.esr_loss.forward(fake_audio_padded, audio)
        snr_loss = self.snr_loss.forward(fake_audio_padded, audio)

        # Least-squares GAN generator objective: push fake scores toward 1,
        # averaged over all resolution and period sub-discriminators.
        score_loss = torch.tensor(0.0, device=audio.device)
        for _, score_fake in res_fake + period_fake:
            score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
        score_loss = score_loss / len(res_fake + period_fake)

        # Total generator loss.
        total_loss_gen = score_loss + stft_loss + esr_loss + snr_loss

        # Least-squares GAN discriminator objective: push real scores toward 1
        # and fake scores toward 0, averaged over all sub-discriminators.
        total_loss_disc = torch.tensor(0.0, device=audio.device)
        for (_, score_fake), (_, score_real) in zip(
            res_fake + period_fake, res_real + period_real
        ):
            total_loss_disc += torch.mean(torch.pow(score_real - 1.0, 2)) + torch.mean(
                torch.pow(score_fake, 2)
            )
        total_loss_disc = total_loss_disc / len(res_fake + period_fake)

        return (
            total_loss_gen,
            total_loss_disc,
            stft_loss,
            score_loss,
            esr_loss,
            snr_loss,
        )
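
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This assumes the repo's default configs
# are importable and that each discriminator returns a list of
# (feature_map, score) tuples, as UnivNet's MRD/MPD discriminators do. The
# batch size, waveform length, and score shapes below are made-up stand-ins,
# not values taken from this codebase. Run with `python -m` from the package
# root so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = UnivnetLoss()

    # Dummy (batch, channels, time) waveforms in place of real model outputs.
    audio = torch.randn(2, 1, 16384)
    fake_audio = torch.randn(2, 1, 16384)

    def dummy_disc_out(n: int) -> List[Tuple[Tensor, Tensor]]:
        # Hypothetical (feature_map, score) pairs for n sub-discriminators.
        return [(torch.randn(2, 16), torch.randn(2, 1, 64)) for _ in range(n)]

    gen_loss, disc_loss, stft_loss, score_loss, esr, snr = criterion(
        audio,
        fake_audio,
        dummy_disc_out(3),  # res_fake
        dummy_disc_out(5),  # period_fake
        dummy_disc_out(3),  # res_real
        dummy_disc_out(5),  # period_real
    )
    print(f"gen={gen_loss.item():.4f} disc={disc_loss.item():.4f}")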