|
import typing
import warnings
from typing import List

import numpy as np
from torch import nn

from .. import AudioSignal
from .. import STFTParams
|
|
|
|
|
class MultiScaleSTFTLoss(nn.Module):
    """Multi-scale STFT magnitude loss, as used in DDSP [1].

    For each configured window length, both signals are analysed with an STFT
    whose hop is a quarter of the window. ``loss_fn`` then compares them
    twice: once on the clamped, power-scaled log10 magnitudes and once on the
    raw magnitudes. The weighted terms of every scale are summed.

    Parameters
    ----------
    window_lengths : List[int], optional
        Window length of each STFT scale, by default [2048, 512]
    loss_fn : typing.Callable, optional
        Distance used to compare two spectrograms, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to magnitudes before the log, by default 1e-5
    mag_weight : float, optional
        Weight of the raw-magnitude term, by default 1.0
    log_weight : float, optional
        Weight of the log-magnitude term, by default 1.0
    pow : float, optional
        Exponent applied to the clamped magnitude before the log,
        by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Stored as ``self.weight``;
        it is not applied inside ``forward``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    window_type : str, optional
        Window type stored in every ``STFTParams`` and passed to
        ``AudioSignal.stft``, by default None

    References
    ----------
    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
       "DDSP: Differentiable Digital Signal Processing."
       International Conference on Learning Representations. 2019.
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()

        # Comparison settings.
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.pow = pow
        self.mag_weight = mag_weight
        self.log_weight = log_weight
        self.weight = weight

        # One STFT configuration per scale; the hop is always window // 4.
        self.stft_params = [
            STFTParams(
                window_type=window_type,
                match_stride=match_stride,
                hop_length=win_len // 4,
                window_length=win_len,
            )
            for win_len in window_lengths
        ]

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Return the multi-scale STFT loss of an estimate against a reference.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.

        Notes
        -----
        The return value of ``stft`` is discarded and ``.magnitude`` is read
        afterwards. This presumably relies on ``AudioSignal.stft`` caching
        its result on the signal object (not visible here), so ``x`` and
        ``y`` end up holding the STFT of the last scale.
        """

        def log_magnitude(mag):
            # Clamp from below first so the log10 never sees values < clamp_eps.
            return mag.clamp(self.clamp_eps).pow(self.pow).log10()

        total = 0.0
        for params in self.stft_params:
            stft_args = (params.window_length, params.hop_length, params.window_type)
            x.stft(*stft_args)
            y.stft(*stft_args)

            log_term = self.loss_fn(log_magnitude(x.magnitude), log_magnitude(y.magnitude))
            total += self.log_weight * log_term

            raw_term = self.loss_fn(x.magnitude, y.magnitude)
            total += self.mag_weight * raw_term
        return total
|
|
|
|
|
class MelSpectrogramLoss(nn.Module):
    """Distance between mel spectrograms, optionally at several resolutions.

    Scale ``i`` uses ``n_mels[i]`` mel bands, an STFT with window
    ``window_lengths[i]`` (hop = window // 4) and the band limits
    ``mel_fmin[i]`` / ``mel_fmax[i]``. For each scale, ``loss_fn`` compares
    the clamped, power-scaled log10 mel spectrograms and the raw mel
    spectrograms. The weighted terms are summed over all scales.

    Parameters
    ----------
    n_mels : List[int], optional
        Number of mel bands per scale, by default [150, 80]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp on the mel magnitudes before the log, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Stored as ``self.weight``;
        it is not applied inside ``forward``.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    mel_fmin : List[float], optional
        Lowest mel frequency per scale, by default [0.0, 0.0]
    mel_fmax : List[float], optional
        Highest mel frequency per scale, by default [None, None]. Each value
        (including ``None``) is passed unchanged to ``mel_spectrogram``.
    window_type : str, optional
        Window type stored in every ``STFTParams`` and passed to
        ``mel_spectrogram``, by default None

    Notes
    -----
    Scales are paired with ``zip``. The number of scales actually used is
    therefore the length of the shortest of ``n_mels``, ``window_lengths``,
    ``mel_fmin`` and ``mel_fmax``. A ``UserWarning`` is emitted at
    construction when this drops entries of ``n_mels`` or
    ``window_lengths``.
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        # One STFT configuration per scale; the hop is always window // 4.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

        # ``forward`` zips the per-scale lists, which silently truncates to
        # the shortest one. Surface the case where a requested scale (an
        # n_mels / window_lengths entry) would be dropped. Longer fmin/fmax
        # lists (e.g. the 2-element defaults with a single scale) are fine.
        num_scales = min(
            len(self.n_mels),
            len(self.stft_params),
            len(self.mel_fmin),
            len(self.mel_fmax),
        )
        if num_scales < max(len(self.n_mels), len(self.stft_params)):
            warnings.warn(
                f"MelSpectrogramLoss will only use {num_scales} scale(s): got "
                f"{len(self.n_mels)} n_mels, {len(self.stft_params)} "
                f"window_lengths, {len(self.mel_fmin)} mel_fmin and "
                f"{len(self.mel_fmax)} mel_fmax entries; extra entries are "
                "ignored.",
                stacklevel=2,
            )

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss, summed over scales.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            # Clamp from below before log10 so empty bins cannot produce -inf.
            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
|
|
|
|
|
class PhaseLoss(nn.Module):
    """Magnitude-weighted squared difference between phase spectrograms.

    Both signals are analysed with a single STFT. The phase difference is
    wrapped back towards ``[-pi, pi]`` and multiplied by the min-max
    normalised magnitude of the estimate ``x``. Low-energy time-frequency
    bins therefore contribute little to the loss.

    Parameters
    ----------
    window_length : int, optional
        Length of STFT window, by default 2048
    hop_length : int, optional
        Hop length of STFT window, by default 512
    weight : float, optional
        Weight of loss, by default 1.0. Stored as ``self.weight``; it is not
        applied inside ``forward``.
    """

    def __init__(
        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
    ):
        super().__init__()

        self.weight = weight
        self.stft_params = STFTParams(window_length, hop_length)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes phase loss between an estimate and a reference signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Phase loss (scalar mean over all bins).
        """
        s = self.stft_params
        x.stft(s.window_length, s.hop_length, s.window_type)
        y.stft(s.window_length, s.hop_length, s.window_type)

        # Wrap the phase difference into [-pi, pi]. Assumes both phases are
        # themselves in [-pi, pi] (TODO confirm against AudioSignal.phase).
        # The raw difference then lies in [-2*pi, 2*pi], and a single
        # +/- 2*pi shift per bin is enough.
        diff = x.phase - y.phase
        diff[diff < -np.pi] += 2 * np.pi
        # Subtract (not add) 2*pi: values above pi must move down into range.
        diff[diff > np.pi] -= 2 * np.pi

        # Min-max normalise the estimate's magnitude to [0, 1] for use as
        # per-bin weights. The range is clamped so that a constant magnitude
        # (e.g. digital silence) gives zero weights instead of 0/0 = NaN.
        magnitude = x.magnitude
        x_min, x_max = magnitude.min(), magnitude.max()
        weights = (magnitude - x_min) / (x_max - x_min).clamp(min=1e-8)

        loss = ((weights * diff) ** 2).mean()
        return loss
|
|