balthou's picture
draft audio sep app
f6b56a2
from gyraudio.audio_separation.properties import SIGNAL, NOISE, TOTAL, LOSS_TYPE, COEFFICIENT, SNR
import torch
def snr(prediction: torch.Tensor, ground_truth: torch.Tensor, reduce="mean") -> torch.Tensor:
    """Signal-to-noise ratio, in decibels, of a prediction against its reference.

    Energies are summed over the last two dimensions, so each index of the
    remaining leading dimensions gets its own SNR value.

    Args:
        prediction (torch.Tensor): estimated signal; must broadcast against
            ``ground_truth``.
        ground_truth (torch.Tensor): reference signal.
        reduce (str, optional): ``"mean"`` averages the per-element SNRs into a
            scalar; any other value returns them unreduced. Defaults to ``"mean"``.

    Returns:
        torch.Tensor: the mean SNR, or the per-element SNRs when not reduced.
    """
    summed_dims = (-2, -1)
    # float32 machine epsilon keeps both the ratio and the log finite when an
    # energy is exactly zero.
    tiny = torch.finfo(torch.float32).eps
    residual = prediction - ground_truth
    reference_energy = (ground_truth ** 2).sum(dim=summed_dims)
    residual_energy = (residual ** 2).sum(dim=summed_dims)
    per_element = 10 * torch.log10((reference_energy + tiny) / (residual_energy + tiny))
    if reduce != "mean":
        return per_element
    return per_element.mean()
# Default cost configuration. Each tracked key maps to the callable that scores
# it (LOSS_TYPE) and, optionally, its weight in the optimized total
# (COEFFICIENT). SNR has no COEFFICIENT, so Costs tracks it but never adds it
# to the total loss.
DEFAULT_COST = {
    SIGNAL: {COEFFICIENT: 0.5, LOSS_TYPE: torch.nn.functional.mse_loss},
    NOISE: {COEFFICIENT: 0.5, LOSS_TYPE: torch.nn.functional.mse_loss},
    SNR: {LOSS_TYPE: snr},
}
class Costs:
    """Accumulate weighted losses and metrics over training steps and epochs.

    Each key of ``costs`` maps to a dict holding a ``LOSS_TYPE`` callable
    ``(prediction, ground_truth) -> value`` and, optionally, a ``COEFFICIENT``.
    Keys with a truthy (non-zero) coefficient are summed, weighted, into the
    ``TOTAL`` loss returned by :meth:`finish_step`. Keys without one are only
    tracked.

    ```
    for epoch in range(...):
        metric.reset_epoch()
        for step in dataloader(...):
            ... # forward
            prediction1, prediction2 = model.forward(batch)  # e.g. one output per key
            metric.update(prediction1, groundtruth1, SIGNAL1)
            metric.update(prediction2, groundtruth2, SIGNAL2)
            loss = metric.finish_step()
            loss.backward()
            ... # backprop
        metric.finish_epoch()
        ... # log metrics
    ```
    """

    def __init__(self, name: str, costs=DEFAULT_COST) -> None:
        """
        Args:
            name (str): label printed at the start of ``repr(self)``.
            costs (dict, optional): key -> {LOSS_TYPE: callable[, COEFFICIENT: number]}.
                Stored by reference, not copied. Defaults to DEFAULT_COST.
        """
        self.name = name
        self.keys = list(costs.keys())
        self.cost = costs

    def __reset_step(self) -> None:
        # Per-step values start at 0. so a key that is not updated during a step
        # contributes nothing to the total or to the epoch sums.
        self.metrics = {key: 0. for key in self.keys}
        # Keys passed to `update` since the last finished step.
        self._updated_keys = set()

    def reset_epoch(self) -> None:
        """Clear the per-step values, the epoch accumulators and the step counter."""
        self.__reset_step()
        self.total_metric = {key: 0. for key in self.keys + [TOTAL]}
        self.count = 0

    def update(self,
               prediction: torch.Tensor,
               ground_truth: torch.Tensor,
               key: str
               ) -> torch.Tensor:
        """Evaluate the cost registered under ``key`` for the current step.

        Args:
            prediction (torch.Tensor): passed unchanged to the loss callable.
            ground_truth (torch.Tensor): passed unchanged to the loss callable.
            key (str): one of the keys of ``costs``. ``TOTAL`` is not allowed,
                because ``finish_step`` computes it.

        Returns:
            torch.Tensor: the value returned by the loss callable. It is also
                stored in ``self.metrics[key]`` until the step is finished.
        """
        assert key != TOTAL
        # Compute loss for a single batch (=step)
        loss = self.cost[key][LOSS_TYPE](prediction, ground_truth)
        self.metrics[key] = loss
        self._updated_keys.add(key)
        return loss

    def finish_step(self) -> torch.Tensor:
        """Close the current step and return the weighted total loss.

        Builds ``TOTAL`` as the coefficient-weighted sum of this step's values.
        Every tensor value is then converted to a Python float with ``.item()``,
        so each must hold a single element. The floats are added to the epoch
        accumulators, and the step counter is incremented.

        Returns:
            torch.Tensor: the weighted total before conversion, which is
                differentiable whenever the underlying losses are. Returns the
                float ``0.`` when no key carries a coefficient.
        """
        # A key not updated in this step must not re-count the value left over
        # from an earlier step. Zero it, exactly as a freshly reset step would.
        for key in self.keys:
            if key not in self._updated_keys:
                self.metrics[key] = 0.
        self._updated_keys = set()

        # Sum the weighted metrics into the total, iterating in key order.
        total = 0.
        for key in self.keys:
            coefficient = self.cost[key].get(COEFFICIENT, False)
            if coefficient:
                total += coefficient * self.metrics[key]
        self.metrics[TOTAL] = total

        # Detach everything to plain floats for logging/averaging. `total`
        # still refers to the original tensor, so it can be backpropagated.
        for key, value in self.metrics.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
                self.metrics[key] = value
            self.total_metric[key] += value
        self.count += 1
        return total

    def finish_epoch(self) -> None:
        """Turn the accumulated epoch sums into per-step averages.

        When no step was finished this epoch, nothing is divided (avoiding a
        division by zero) and the accumulators stay at 0.
        """
        if self.count == 0:
            return
        for key in self.total_metric:
            self.total_metric[key] /= self.count

    def __repr__(self) -> str:
        # Format: "<name>\t:\t<key>: <value> | ..." (trailing separator kept).
        fields = "".join(f"{key}: {value:.3e} | " for key, value in self.total_metric.items())
        return f"{self.name}\t:\t{fields}"