import torch
from torch import nn

from .. import AudioSignal


class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.
""" | |
def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): | |
self.attribute = attribute | |
self.weight = weight | |
super().__init__(**kwargs) | |
    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or None), by default 'mean'
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.
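
    Notes
    -----
    With ``scaling`` enabled, the target is the projection of the estimate
    onto the reference, ``s_target = (<e, s> / ||s||^2) * s``, and the loss
    is ``-10 * log10(||s_target||^2 / ||e - s_target||^2)``, i.e. the
    negated SI-SDR in dB, so lower is better.

    Examples
    --------
    A sketch of typical usage (raw ``(batch, channels, time)`` tensors are
    also accepted in place of AudioSignals)::

        >>> loss_fn = SISDRLoss()
        >>> references = AudioSignal(torch.randn(4, 1, 16000), 16000)
        >>> estimates = AudioSignal(torch.randn(4, 1, 16000), 16000)
        >>> loss = loss_fn(references, estimates)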
""" | |
def __init__( | |
self, | |
scaling: int = True, | |
reduction: str = "mean", | |
zero_mean: int = True, | |
clip_min: int = None, | |
weight: float = 1.0, | |
): | |
self.scaling = scaling | |
self.reduction = reduction | |
self.zero_mean = zero_mean | |
self.clip_min = clip_min | |
self.weight = weight | |
super().__init__() | |
    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate
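
        # Least-squares projection: scale = <estimate, reference> / ||reference||^2
        # is the gain that best rescales the reference to match the estimate,
        # so e_true below is the projection of the estimate onto the reference.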
        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
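
        # Negated log power ratio in dB: minimizing this loss maximizes SI-SDR.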
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
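

# A rough sketch of the tensor fallback path (an illustration, not part of
# the module): since ``forward`` only unwraps AudioSignals, raw waveform
# tensors shaped (batch, channels, time) can be passed directly, e.g.
#
#   loss_fn = SISDRLoss(reduction="mean")
#   references = torch.randn(4, 1, 16000)
#   estimates = references + 0.1 * torch.randn(4, 1, 16000)
#   loss = loss_fn(references, estimates)  # negated SI-SDR, lower is better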