File size: 5,864 Bytes
e27a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877f661
e27a095
efe955e
e27a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852815e
 
e27a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch
import torch.nn.functional as F


def stft(x, fft_size, hop_size, win_length, window):
    """
    Compute the magnitude of the short-time Fourier transform of a signal.
    :param x: Tensor, Input signal tensor (B, T).
    :param fft_size: int, FFT size.
    :param hop_size: int, Hop size.
    :param win_length: int, Window length.
    :param window: Tensor, Window tensor handed straight to ``torch.stft``.
    :return: Tensor, Magnitude spectrogram laid out as ``torch.stft`` returns it,
        i.e. (B, fft_size // 2 + 1, #frames) for a real-valued input; the result
        is NOT transposed to frames-first.
    """
    spectrum = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Only the modulus of each complex bin is kept; phase is discarded.
    return torch.abs(spectrum)

class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module.

    Computes ``||y_mag - x_mag||_F / ||y_mag||_F`` over the whole batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_mag, y_mag):
        """
        Calculate forward propagation.
        :param x_mag: Tensor, Magnitude spectrogram of predicted signal.
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (same shape as x_mag).
        :return: Tensor, Scalar spectral convergence loss value.
        """
        residual_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        # An all-zero (silent) target makes the denominator exactly 0, which
        # would produce NaN (0 / 0) or inf. Flooring it at machine epsilon of
        # the working dtype keeps the loss finite and leaves the value
        # untouched for every non-degenerate target.
        floor = torch.finfo(reference_norm.dtype).eps
        return residual_norm / reference_norm.clamp(min=floor)


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module (mean absolute log-magnitude difference)."""

    def __init__(self):
        super().__init__()

    def forward(self, x_mag, y_mag):
        """
        Calculate forward propagation.
        :param x_mag: Tensor, Magnitude spectrogram of predicted signal.
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (same shape as x_mag).
        :return: Tensor, Scalar log STFT magnitude loss value.
        """
        # Both magnitudes are floored at 1e-8 so the logarithm stays finite
        # for bins that are exactly zero.
        log_target = y_mag.clamp(min=1e-8).log()
        log_predicted = x_mag.clamp(min=1e-8).log()
        return F.l1_loss(log_target, log_predicted)


class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss: spectral convergence plus log magnitude."""

    def __init__(
        self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
        band="full"
    ):
        """
        Initialize STFT loss module.
        :param fft_size: int, FFT size.
        :param shift_size: int, Hop size between successive frames.
        :param win_length: int, Window length.
        :param window: str, Name of a window factory on the ``torch`` namespace.
        :param band: str, "full" or "high"; any other value makes ``forward``
            raise NotImplementedError.
        """
        super().__init__()
        self.band = band
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length

        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

        # NOTE(kan-bayashi): Use register_buffer to fix #223
        make_window = getattr(torch, window)
        self.register_buffer("window", make_window(win_length))

    def forward(self, x, y):
        """
        Calculate forward propagation.
        :param x: Tensor, Predicted signal (B, T).
        :param y: Tensor, Groundtruth signal (B, T).
        :return:
        Tensor, Spectral convergence loss value.
        Tensor, Log STFT magnitude loss value.
        :raises NotImplementedError: if ``band`` is neither "full" nor "high".
        """
        stft_args = (self.fft_size, self.shift_size, self.win_length, self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)

        if self.band == "high":
            # Keep only the upper half of dim 1, which is the frequency axis
            # of the torch.stft output.
            start = x_mag.shape[1] // 2
            x_mag = x_mag[:, start:, :]
            y_mag = y_mag[:, start:, :]
        elif self.band != "full":
            raise NotImplementedError

        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module.

    Averages the spectral convergence and log-magnitude losses of several
    STFTLoss instances, each configured with its own FFT/hop/window sizes.
    """

    def __init__(self,
                 fft_sizes=None, hop_sizes=None, win_lengths=None,
                 window="hann_window", sc_lambda=0.1, mag_lambda=0.1, band="full",
                 ):
        """
        Initialize Multi resolution STFT loss module.
        :param fft_sizes: list, List of FFT sizes (defaults to [1024, 2048, 512]).
        :param hop_sizes: list, List of hop sizes (defaults to [120, 240, 50]).
        :param win_lengths: list, List of window lengths (defaults to [600, 1200, 240]).
        :param window: str, Window function type.
        :param sc_lambda: float, a balancing factor across different losses.
        :param mag_lambda: float, a balancing factor across different losses.
        :param band: str, high-band or full-band loss
        """
        super().__init__()
        # None (or an empty list) falls back to the default resolutions.
        fft_sizes = fft_sizes or [1024, 2048, 512]
        hop_sizes = hop_sizes or [120, 240, 50]
        win_lengths = win_lengths or [600, 1200, 240]

        self.sc_lambda = sc_lambda
        self.mag_lambda = mag_lambda

        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(fs, ss, wl, window, band)]

    def forward(self, x, y):
        """
        Calculate forward propagation.
        :param x: Tensor, Predicted signal (B, T) or (B, #subband, T).
        :param y: Tensor, Groundtruth signal (B, T) or (B, #subband, T).
        :return:
        Tensor, Multi resolution spectral convergence loss value.
        Tensor, Multi resolution log STFT magnitude loss value.
        """
        if x.dim() == 3:
            # reshape (not view) so non-contiguous inputs, e.g. the result of
            # a permute/transpose, are flattened instead of raising.
            x = x.reshape(-1, x.size(2))  # (B, C, T) -> (B x C, T)
            y = y.reshape(-1, y.size(2))  # (B, C, T) -> (B x C, T)

        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(x, y)
            sc_loss = sc_loss + sc_l
            mag_loss = mag_loss + mag_l

        # Weight each term, then average over the number of resolutions.
        num_resolutions = len(self.stft_losses)
        sc_loss = sc_loss * self.sc_lambda / num_resolutions
        mag_loss = mag_loss * self.mag_lambda / num_resolutions

        return sc_loss, mag_loss


# No script behaviour is defined: executing this file directly does nothing.
if __name__ == '__main__':
    pass