File size: 12,410 Bytes
f32cd36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
import typing
from typing import List

import torch
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import STFTParams
from torch import nn


class L1Loss(nn.L1Loss):
    """L1 (mean absolute error) loss between two AudioSignals.

    By default the ``audio_data`` of both signals is compared, but any
    attribute of an AudioSignal can be selected instead. Inputs that are not
    AudioSignals are handed straight to :class:`torch.nn.L1Loss`.

    Parameters
    ----------
    attribute : str, optional
        Name of the AudioSignal attribute to compare, by default ``"audio_data"``.
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the module;
        ``forward`` does not apply it.
    **kwargs
        Forwarded unchanged to :class:`torch.nn.L1Loss` (e.g. ``reduction``).

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.attribute = attribute
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Return the L1 loss between ``x`` and ``y``.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal (or a tensor).
        y : AudioSignal
            Reference signal (or a tensor).

        Returns
        -------
        torch.Tensor
            L1 loss between the selected attributes (or between the raw inputs).
        """
        # Only the type of ``x`` is inspected; ``y`` is assumed to be the same kind.
        if not isinstance(x, AudioSignal):
            return super().forward(x, y)
        estimate = getattr(x, self.attribute)
        reference = getattr(y, self.attribute)
        return super().forward(estimate, reference)


class SISDRLoss(nn.Module):
    """Negated Scale-Invariant Source-to-Distortion Ratio (SI-SDR) loss, in dB.

    Computes ``-SI-SDR`` between a batch of reference and estimated audio
    signals or aligned features, so that lower values are better.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant SDR (True) or plain
        signal-to-noise ratio (False), by default True.
    reduction : str, optional
        How to reduce across the batch: ``"mean"``, ``"sum"``, or any other
        value for no reduction, by default ``"mean"``.
    zero_mean : bool, optional
        Zero-mean the references and estimates before computing the loss,
        by default True.
    clip_min : float, optional
        The minimum possible loss value. Helps the network not focus on
        making already-good examples better, by default None (no clipping).
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the module;
        ``forward`` does not apply it.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: typing.Optional[float] = None,
        weight: float = 1.0,
    ):
        super().__init__()
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Compute the negated SI-SDR.

        Note that, unlike the other losses in this module, the FIRST argument
        is treated as the reference and the SECOND as the estimate.

        Parameters
        ----------
        x : AudioSignal
            Reference signal (or a tensor whose first axis is the batch).
        y : AudioSignal
            Estimate signal (or a tensor with the same number of elements
            per batch item as ``x``).

        Returns
        -------
        torch.Tensor
            A scalar for ``reduction`` ``"mean"``/``"sum"``; otherwise a
            per-example tensor of shape ``(batch, 1)``.
        """
        eps = 1e-8
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        # Flatten every non-batch axis into one: (nb, ...) -> (nb, n_samples, 1),
        # so that the samples lie along dim 1.
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal scale of the reference to best match the estimate
        # (only used for the scale-invariant variant).
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # eps in the denominator keeps the ratio finite when the estimate
        # matches the reference exactly (noise == 0); without it the loss is
        # -inf and backward produces NaN gradients (even with clip_min).
        sdr = -10 * torch.log10(signal / (noise + eps) + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr


class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    For every window length an STFT of the estimate and of the reference is
    taken (hop length = window length // 4), and two terms are accumulated:
    ``log_weight * loss_fn`` on the log10 of the clamped, power-raised
    magnitudes, and ``mag_weight * loss_fn`` on the raw magnitudes.

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to the magnitude before the power and log,
        by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the module;
        ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
        Forwarded to every STFT call.
    window_type : str, optional
        Window function used for every STFT, by default None (passed through
        to ``AudioSignal.stft`` as-is).

    References
    ----------
    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: typing.Optional[str] = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss, summed over all window lengths.
        """
        loss = 0.0
        for s in self.stft_params:
            # Previously match_stride was stored in STFTParams but never passed
            # on, so the constructor argument had no effect.
            stft_kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
                "match_stride": s.match_stride,
            }
            # stft() is called for its effect on the signal; presumably it
            # stores the STFT that ``.magnitude`` reads next — TODO confirm.
            x.stft(**stft_kwargs)
            y.stft(**stft_kwargs)
            x_mag = x.magnitude
            y_mag = y.magnitude
            loss += self.log_weight * self.loss_fn(
                x_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
        return loss


class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    For every scale, the mel spectrograms of the estimate and the reference
    are computed, and two terms are accumulated: ``log_weight * loss_fn`` on
    the log10 of the clamped, power-raised mels, and ``mag_weight * loss_fn``
    on the raw mels.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
        (hop length = window length // 4).
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to the mel magnitudes before the power and log,
        by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the module;
        ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
        Forwarded to every STFT call.
    mel_fmin : List[float], optional
        Lowest mel frequency per scale, by default [0.0, 0.0]
    mel_fmax : List[Optional[float]], optional
        Highest mel frequency per scale, by default [None, None]; ``None`` is
        passed through to ``AudioSignal.mel_spectrogram`` unchanged.
    window_type : str, optional
        Window function used for every STFT, by default None (passed through
        as-is).

    The per-scale lists ``n_mels``, ``mel_fmin``, ``mel_fmax`` and
    ``window_lengths`` are zipped together, so only as many scales as the
    SHORTEST list are used; keep them the same length.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[typing.Optional[float]] = [None, None],
        window_type: typing.Optional[str] = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss, summed over all scales.
        """
        loss = 0.0
        # zip() truncates to the shortest of the four per-scale lists.
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            # match_stride was previously omitted here, so the constructor
            # argument had no effect on the STFT.
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
                "match_stride": s.match_stride,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss


class GANLoss(nn.Module):
    """Least-squares adversarial losses for a (multi-)discriminator.

    Runs ``discriminator`` on the ``audio_data`` of generated ("fake") and
    ground-truth ("real") signals and derives separate losses for the
    discriminator and for the generator.

    The discriminator is expected to return a sequence with one entry per
    sub-discriminator, each entry itself a sequence of tensors. The LAST
    tensor of each entry is used as the score. The preceding tensors are
    used as intermediate feature maps for the feature-matching loss.

    Parameters
    ----------
    discriminator : callable
        Called with a signal's ``audio_data``.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        """Run the discriminator on ``fake`` and then on ``real``.

        Returns
        -------
        tuple
            ``(d_fake, d_real)``, the discriminator outputs for each signal.
        """
        outputs_fake = self.discriminator(fake.audio_data)
        outputs_real = self.discriminator(real.audio_data)
        return outputs_fake, outputs_real

    def discriminator_loss(self, fake, real):
        """Least-squares discriminator loss.

        ``fake`` is cloned and detached first, so this loss does not
        backpropagate into whatever produced it. Scores on fake audio are
        pushed toward 0 and scores on real audio toward 1.

        Returns
        -------
        torch.Tensor
            Sum over sub-discriminators of
            ``mean(score_fake ** 2) + mean((1 - score_real) ** 2)``.
        """
        outputs_fake, outputs_real = self.forward(fake.clone().detach(), real)

        total = 0
        for feats_fake, feats_real in zip(outputs_fake, outputs_real):
            score_fake = feats_fake[-1]
            score_real = feats_real[-1]
            total += torch.mean(score_fake**2)
            total += torch.mean((1 - score_real) ** 2)
        return total

    def generator_loss(self, fake, real):
        """Adversarial and feature-matching losses for the generator.

        Returns
        -------
        tuple
            ``(loss_g, loss_feature)``. ``loss_g`` is the sum of
            ``mean((1 - score_fake) ** 2)`` over sub-discriminators.
            ``loss_feature`` is the summed L1 distance between every
            intermediate fake feature map and the detached real one.
        """
        outputs_fake, outputs_real = self.forward(fake, real)

        adversarial = 0
        for feats_fake in outputs_fake:
            score_fake = feats_fake[-1]
            adversarial += torch.mean((1 - score_fake) ** 2)

        feature_matching = 0
        for idx, feats_fake in enumerate(outputs_fake):
            feats_real = outputs_real[idx]
            # Every entry except the last (the score) is a feature map.
            for layer in range(len(feats_fake) - 1):
                target = feats_real[layer].detach()
                feature_matching += F.l1_loss(feats_fake[layer], target)
        return adversarial, feature_matching