Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,906 Bytes
71de706 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from typing import Optional

import torch
from torch import nn

from .. import AudioSignal
class L1Loss(nn.L1Loss):
    """L1 (mean absolute error) loss between AudioSignals.

    Compares the ``audio_data`` attribute of the two signals by default;
    any other attribute of an AudioSignal can be selected instead.

    Parameters
    ----------
    attribute : str, optional
        Name of the AudioSignal attribute to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    **kwargs
        Forwarded unchanged to ``torch.nn.L1Loss`` (e.g. ``reduction``).
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.attribute = attribute
        self.weight = weight

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        # Only ``x`` is type-checked: when it is not an AudioSignal, both
        # arguments are handed to nn.L1Loss unchanged.
        if not isinstance(x, AudioSignal):
            return super().forward(x, y)

        estimate_values = getattr(x, self.attribute)
        reference_values = getattr(y, self.attribute)
        return super().forward(estimate_values, reference_values)
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    The returned value is the *negative* SI-SDR in dB, so minimizing this
    loss maximizes SI-SDR.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch: 'mean', 'sum', or 'none',
        by default 'mean'. Any value other than 'mean' or 'sum' returns
        the unreduced per-item losses.
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.
    """

    def __init__(
        self,
        scaling: bool = True,
        reduction: str = "mean",
        zero_mean: bool = True,
        clip_min: Optional[float] = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal or torch.Tensor
            Reference. NOTE: in this loss ``x`` is treated as the reference
            and ``y`` as the estimate.
        y : AudioSignal or torch.Tensor
            Estimate. Must have the same batch size and the same number of
            elements per batch item as ``x``.

        Returns
        -------
        torch.Tensor
            Negative SI-SDR in dB. A scalar for 'mean'/'sum' reduction,
            otherwise a tensor of shape (nb, 1).
        """
        eps = 1e-8
        # Inputs presumably (nb, nc, nt); all non-batch dims are flattened below.
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        # Flatten everything after the batch axis: (nb, samples, 1).
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        # Optimal scale <estimate, reference> / ||reference||^2; for rank-3
        # tensors dim=-2 is the same sample axis as dim=1.
        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # eps in the denominator keeps the ratio finite: without it a perfect
        # estimate (noise == 0) yields -inf and an all-silent pair yields
        # 0/0 = NaN, which poisons the reduced loss and its gradients.
        sdr = -10 * torch.log10(signal / (noise + eps) + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
|