#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Running categorical-accuracy metric plus a small PesqLoss demo (``main``)."""
from typing import Dict, Optional

import torch

try:
    # Only the demo ``main`` needs torch_pesq; keep the metric classes
    # importable when that package is not installed.
    from torch_pesq import PesqLoss
except ImportError:  # pragma: no cover - depends on the environment
    PesqLoss = None


class Pesq(object):
    """Placeholder: currently stores nothing and defines no behaviour."""

    def __init__(self):
        pass


class CategoricalAccuracy(object):
    """Running (optionally masked) top-k categorical accuracy.

    Call the instance once per batch to accumulate counts, then read the
    accumulated value with :meth:`get_metric`.

    Parameters
    ----------
    top_k:
        A row counts as correct if its gold label is among the ``top_k``
        highest-scoring classes. Must be > 0.
    tie_break:
        Only allowed with ``top_k == 1``. If True, a row whose maximum score
        is shared by several classes contributes ``1 / n_tied`` when the gold
        label is one of the tied classes.

    Raises
    ------
    AssertionError
        If ``tie_break`` is combined with ``top_k > 1``, or ``top_k <= 0``.
    """

    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
        if top_k > 1 and tie_break:
            raise AssertionError("Tie break in Categorical Accuracy "
                                 "can be done only for maximum (top_k = 1)")
        if top_k <= 0:
            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
        self._top_k = top_k
        self._tie_break = tie_break
        # Running sums, kept as Python floats (see ``reset``).
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None) -> None:
        """Accumulate correct/total counts for one batch.

        Parameters
        ----------
        predictions:
            Class scores whose last dimension is the number of classes.
        gold_labels:
            Class ids with one dimension fewer than ``predictions``.
        mask:
            Optional tensor with one entry per gold label. It is multiplied
            into the per-label correctness and its sum is added to the total,
            so zero entries are ignored.

        Raises
        ------
        AssertionError
            If ``gold_labels`` has the wrong rank or contains an id
            ``>= num_classes``.
        """
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but "
                                 "found tensor of shape: {}".format(gold_labels.size()))
        if (gold_labels >= num_classes).any():
            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                 "the number of classes.".format(num_classes))

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted as well.
        predictions = predictions.reshape(-1, num_classes)
        gold_labels = gold_labels.reshape(-1).long()

        if not self._tie_break:
            # Top-k class indices per row (fewer if there are < k classes).
            # top_k == 1 is the common case and max() is cheaper than topk().
            if self._top_k == 1:
                top_k = predictions.max(-1)[1].unsqueeze(-1)
            else:
                top_k = predictions.topk(min(self._top_k, num_classes), -1)[1]
            # (rows, k): 1.0 where a predicted index equals the gold label.
            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
        else:
            # A row is correct if its gold label is among the max-scoring
            # classes; the credit is split equally among the tied classes.
            max_predictions = predictions.max(-1)[0]
            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
            # For row i, check whether class gold_labels[i] attains the max.
            # The row index lives on the labels' device to avoid mixing
            # devices in the advanced index.
            rows = torch.arange(gold_labels.numel(), device=gold_labels.device)
            correct = max_predictions_mask[rows, gold_labels].float()
            tie_counts = max_predictions_mask.sum(-1)
            correct /= tie_counts.float()
            correct.unsqueeze_(-1)

        if mask is not None:
            correct *= mask.reshape(-1, 1).float()
            self.total_count += float(mask.sum())
        else:
            self.total_count += gold_labels.numel()
        # Convert to a Python float so no tensor (or device memory) is retained.
        self.correct_count += float(correct.sum())

    def get_metric(self, reset: bool = False) -> Dict[str, float]:
        """Return the accuracy accumulated since the last reset.

        Parameters
        ----------
        reset:
            If True, clear the running counts after computing the value.

        Returns
        -------
        dict
            ``{'accuracy': correct / total}``, or ``{'accuracy': 0.0}`` when
            the accumulated total is (near) zero.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self) -> None:
        """Zero the running correct and total counts."""
        self.correct_count = 0.0
        self.total_count = 0.0


def main():
    """Demo: compute PesqLoss MOS and loss on two random signals and print them."""
    if PesqLoss is None:
        raise ImportError("main() requires the optional 'torch_pesq' package")
    pesq = PesqLoss(0.5, sample_rate=8000)
    reference = torch.randn(1, 44100)
    degraded = torch.randn(1, 44100)
    mos = pesq.mos(reference, degraded)
    loss = pesq(reference, degraded)
    print(mos, loss)


if __name__ == '__main__':
    main()