File size: 3,751 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import List, Tuple

from auraloss.time import ESRLoss, SDSDRLoss, SISDRLoss, SNRLoss
import torch
from torch import Tensor
from torch.nn import Module

from models.config import VocoderBasicConfig, VocoderModelConfig

from .multi_resolution_stft_loss import MultiResolutionSTFTLoss


class UnivnetLoss(Module):
    r"""UnivnetLoss is a PyTorch Module that calculates the generator and discriminator losses for Univnet.

    The generator loss is the sum of an adversarial score loss, a weighted
    multi-resolution STFT loss, an ``ESRLoss`` term and an ``SNRLoss`` term.
    The discriminator loss is a squared-error objective over the real and
    fake discriminator scores.

    Note:
        ``sisdr_loss`` and ``sdsdr_loss`` are instantiated for availability on
        the module, but ``forward`` does not use them.
    """

    def __init__(self):
        r"""Initializes the UnivnetLoss module.

        The STFT weighting factor and the STFT resolutions are read from
        default-constructed ``VocoderBasicConfig`` and ``VocoderModelConfig``
        instances; the module takes no arguments.
        """
        super().__init__()

        train_config = VocoderBasicConfig()

        # Multiplier applied to the (spectral convergence + magnitude) STFT loss.
        self.stft_lamb = train_config.stft_lamb
        self.model_config = VocoderModelConfig()

        self.stft_criterion = MultiResolutionSTFTLoss(self.model_config.mrd.resolutions)
        self.esr_loss = ESRLoss()
        self.sisdr_loss = SISDRLoss()
        self.snr_loss = SNRLoss()
        self.sdsdr_loss = SDSDRLoss()

    def forward(
        self,
        audio: Tensor,
        fake_audio: Tensor,
        res_fake: List[Tuple[Tensor, Tensor]],
        period_fake: List[Tuple[Tensor, Tensor]],
        res_real: List[Tuple[Tensor, Tensor]],
        period_real: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
    ]:
        r"""Calculate the losses for the generator and discriminator.

        Args:
            audio (torch.Tensor): The real audio samples. Indexed along dim 2
                for its length and squeezed on dim 1, so presumably shaped
                ``(batch, 1, time)`` - confirm against the caller.
            fake_audio (torch.Tensor): The generated audio samples, with the
                same layout as ``audio``.
            res_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio.
                Only the second element of each tuple (the score) is used.
            period_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio in the period.
                Only the second element of each tuple (the score) is used.
            res_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio.
                Only the second element of each tuple (the score) is used.
            period_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio in the period.
                Only the second element of each tuple (the score) is used.

        Returns:
            tuple: A 6-tuple ``(total_loss_gen, total_loss_disc, stft_loss,
            score_loss, esr_loss, snr_loss)`` where ``total_loss_gen`` is
            ``score_loss + stft_loss + esr_loss + snr_loss``.
        """
        # Calculate the STFT loss
        sc_loss, mag_loss = self.stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
        stft_loss = (sc_loss + mag_loss) * self.stft_lamb

        # Pad the fake audio (on the right of the last dim) to match the length of the real audio
        padding = audio.shape[2] - fake_audio.shape[2]
        fake_audio_padded = torch.nn.functional.pad(fake_audio, (0, padding))

        # Invoke the sub-modules through ``__call__`` rather than ``.forward``
        # so that any hooks registered on them are honoured.
        esr_loss = self.esr_loss(fake_audio_padded, audio)
        snr_loss = self.snr_loss(fake_audio_padded, audio)

        # Concatenate the discriminator outputs once instead of on every use.
        fake_outputs = res_fake + period_fake
        real_outputs = res_real + period_real
        num_outputs = len(fake_outputs)

        # Calculate the score loss: mean squared distance of each fake score from 1,
        # averaged over all discriminator outputs.
        score_loss = torch.tensor(0.0, device=audio.device)
        for _, score_fake in fake_outputs:
            score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))

        score_loss = score_loss / num_outputs

        # Calculate the total generator loss
        total_loss_gen = score_loss + stft_loss + esr_loss + snr_loss

        # Calculate the discriminator loss: real scores are pushed towards 1,
        # fake scores towards 0, averaged over all discriminator outputs.
        total_loss_disc = torch.tensor(0.0, device=audio.device)
        for (_, score_fake), (_, score_real) in zip(fake_outputs, real_outputs):
            total_loss_disc += torch.mean(torch.pow(score_real - 1.0, 2)) + torch.mean(
                torch.pow(score_fake, 2)
            )

        total_loss_disc = total_loss_disc / num_outputs

        return (
            total_loss_gen,
            total_loss_disc,
            stft_loss,
            score_loss,
            esr_loss,
            snr_loss,
        )