import torch
from torch import nn

from .. import AudioSignal


class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        # If AudioSignals are passed, compare the configured attribute;
        # raw tensors are compared directly.
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


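# Example usage of ``L1Loss`` (an illustrative sketch, not from the library's
# documentation; it assumes AudioSignal accepts a (batch, channels, samples)
# tensor plus a sample rate, and that ``magnitude`` is one of its available
# STFT attributes):
#
#     est = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
#     ref = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)
#     waveform_loss = L1Loss()(est, ref)             # compares .audio_data
#     magnitude_loss = L1Loss(attribute="magnitude")(est, ref)

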
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or 'none'), by default 'mean'
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Keeps the network
        from focusing on making already-good examples better,
        by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: float = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # Accept either AudioSignals or raw tensors of matching shape.
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        # Flatten channels and time into (batch, samples, 1) so that
        # samples lie on axis 1.
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # Optionally remove the DC offset before computing the loss.
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        # Optimal rescaling of the reference, or 1 for plain SNR.
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )
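
        # For reference, the quantity computed below is the negative SI-SDR
        # (cf. Le Roux et al., "SDR - half-baked or well done?", 2019):
        #     SI-SDR = 10 * log10(||alpha * s||^2 / ||alpha * s - s_hat||^2)
        # where s is the reference, s_hat the estimate, and alpha is the
        # scale computed above (alpha = <s_hat, s> / ||s||^2); the sign is
        # flipped so that minimizing this loss maximizes SI-SDR.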
        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        sdr = -10 * torch.log10(signal / noise + eps)

        # Optionally clamp, so the network does not keep optimizing
        # examples that are already good enough.
        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
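
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library's
    # API). ``SISDRLoss.forward`` also accepts raw tensors, so no AudioSignal
    # is required; run via ``python -m`` so the relative import resolves.
    torch.manual_seed(0)
    reference = torch.randn(2, 1, 16000)
    estimate = reference + 0.1 * torch.randn_like(reference)

    # A near-clean estimate yields a strongly negative loss (high SI-SDR).
    loss = SISDRLoss()(reference, estimate)
    print(f"negative SI-SDR: {loss.item():.2f} dB")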