Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from typing import Optional | |
import torch | |
from torch_pesq import PesqLoss | |
class Pesq:
    """Placeholder PESQ metric; it currently holds no state and has no behaviour."""

    def __init__(self):
        """Create an empty instance (nothing to initialise yet)."""
class CategoricalAccuracy(object):
    """Accumulate (top-k) categorical accuracy over successive batches.

    Each call adds the number of correct predictions and the number of
    scored items to running totals; ``get_metric`` returns their ratio.

    With ``tie_break=True`` (only allowed for ``top_k == 1``) a row whose gold
    label is among several equally-maximal scores earns ``1 / number_of_ties``
    instead of a full point.
    """

    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
        """Validate and store the configuration, and zero the counters.

        Raises
        ------
        AssertionError
            If ``top_k > 1`` is combined with ``tie_break``, or if
            ``top_k <= 0``.
        """
        if top_k > 1 and tie_break:
            raise AssertionError("Tie break in Categorical Accuracy "
                                 "can be done only for maximum (top_k = 1)")
        if top_k <= 0:
            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
        self._top_k = top_k
        self._tie_break = tie_break
        # Running totals, always kept as plain Python floats.
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None) -> None:
        """Update the running totals with one batch.

        Parameters
        ----------
        predictions:
            Scores whose last dimension indexes the classes.
        gold_labels:
            Class ids with exactly one dimension fewer than ``predictions``.
        mask:
            Optional tensor that is flattened to one weight per label; entries
            are multiplied into the per-item score and summed into the total.
            Presumably 0/1 valued with as many elements as ``gold_labels``
            (not checked here).

        Raises
        ------
        AssertionError
            If the dimensionality of ``gold_labels`` does not match, or a
            label id is ``>=`` the number of classes.
        """
        # Some sanity checks.
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
                                 "found tensor of shape: {}".format(predictions.size()))
        if (gold_labels >= num_classes).any():
            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                 "the number of classes.".format(num_classes))

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising a RuntimeError.
        predictions = predictions.reshape(-1, num_classes)
        gold_labels = gold_labels.reshape(-1).long()

        if not self._tie_break:
            # Top K indexes of the predictions (or fewer, if there aren't K of them).
            # Special case topk == 1, because it's common and .max() is much faster than .topk().
            if self._top_k == 1:
                top_k = predictions.max(-1)[1].unsqueeze(-1)
            else:
                top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]

            # Shape (rows, top_k): 1.0 where a retained index equals the gold label.
            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
        else:
            # A prediction is correct if the gold label falls on any of the
            # maximal scores; the point is split evenly between the ties.
            max_predictions = predictions.max(-1)[0]
            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
            # max_predictions_mask is (rows, num_classes); pick, for each row,
            # the entry at that row's gold label. The row index is created on
            # the labels' device so both index tensors live in the same place.
            row_index = torch.arange(gold_labels.numel(), device=gold_labels.device)
            correct = max_predictions_mask[row_index, gold_labels].float()
            tie_counts = max_predictions_mask.sum(-1)
            correct /= tie_counts.float()
            correct.unsqueeze_(-1)

        # .item() keeps the counters as Python floats: no tensors are retained
        # as metric state and the totals do not suffer float32 rounding.
        if mask is not None:
            correct *= mask.reshape(-1, 1).float()
            self.total_count += mask.sum().item()
        else:
            self.total_count += gold_labels.numel()
        self.correct_count += correct.sum().item()

    def get_metric(self, reset: bool = False) -> dict:
        """
        Parameters
        ----------
        reset:
            When true, the counters are zeroed after the value is computed.

        Returns
        -------
        A dict ``{'accuracy': value}`` holding the accumulated accuracy, or
        ``0.0`` when nothing has been counted yet.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self) -> None:
        """Zero both running totals."""
        self.correct_count = 0.0
        self.total_count = 0.0
def main():
    """Run PesqLoss on two random signals and print its ``mos`` and call outputs."""
    pesq_metric = PesqLoss(0.5, sample_rate=8000)

    # Draw the reference first and the degraded signal second, so the random
    # number generator is consumed in the same order as before.
    clean_signal = torch.randn(1, 44100)
    noisy_signal = torch.randn(1, 44100)

    mos_value = pesq_metric.mos(clean_signal, noisy_signal)
    loss_value = pesq_metric(clean_signal, noisy_signal)

    print(mos_value, loss_value)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()