#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional
import torch
from torch_pesq import PesqLoss
class Pesq(object):
    # Placeholder: the actual PESQ computation is done directly with
    # torch_pesq.PesqLoss in main() below.
    def __init__(self):
        pass
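
# --- Illustrative sketch (not part of the original file) ---
# Pesq above is an empty placeholder. Below is one guess at a metric-style
# wrapper around torch_pesq.PesqLoss, mirroring the get_metric()/reset()
# interface of CategoricalAccuracy further down; it reuses only the PesqLoss
# calls exercised in main(), and the accumulator design is an assumption.
class PesqMetricSketch(object):
    def __init__(self, factor: float = 0.5, sample_rate: int = 16000):
        self._loss = PesqLoss(factor, sample_rate=sample_rate)
        self._mos_sum = 0.0
        self._count = 0

    def __call__(self, reference: torch.Tensor, degraded: torch.Tensor):
        # mos() yields one MOS estimate per batch item (per the torch_pesq
        # usage shown in main() below).
        mos = self._loss.mos(reference, degraded)
        self._mos_sum += float(mos.sum())
        self._count += mos.numel()

    def get_metric(self, reset: bool = False):
        average = self._mos_sum / self._count if self._count else 0.0
        if reset:
            self._mos_sum, self._count = 0.0, 0
        return {'mos': average}
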
class CategoricalAccuracy(object):
    def __init__(self, top_k: int = 1, tie_break: bool = False) -> None:
        if top_k > 1 and tie_break:
            raise AssertionError("Tie break in Categorical Accuracy "
                                 "can be done only for maximum (top_k = 1)")
        if top_k <= 0:
            raise AssertionError("top_k passed to Categorical Accuracy must be > 0")
        self._top_k = top_k
        self._tie_break = tie_break
        self.correct_count = 0.
        self.total_count = 0.
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        # Some sanity checks.
        num_classes = predictions.size(-1)
        if gold_labels.dim() != predictions.dim() - 1:
            raise AssertionError("gold_labels must have dimension == predictions.dim() - 1 but "
                                 "found tensor of shape: {}".format(gold_labels.size()))
        if (gold_labels >= num_classes).any():
            raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, "
                                 "the number of classes.".format(num_classes))
        predictions = predictions.view((-1, num_classes))
        gold_labels = gold_labels.view(-1).long()
        if not self._tie_break:
            # Top K indexes of the predictions (or fewer, if there aren't K of them).
            # Special case top_k == 1, because it's common and .max() is much faster than .topk().
            if self._top_k == 1:
                top_k = predictions.max(-1)[1].unsqueeze(-1)
            else:
                top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1]
            # This is of shape (batch_size, top_k).
            correct = top_k.eq(gold_labels.unsqueeze(-1)).float()
        else:
            # A prediction is correct if the gold label falls on any of the maximum
            # scores; the credit is split evenly across the tied maxima. For example,
            # logits [1.0, 1.0, 0.2] with gold label 0 have two tied maxima, so the
            # row contributes 0.5 to correct_count.
            max_predictions = predictions.max(-1)[0]
            max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1))
            # max_predictions_mask is (rows x num_classes) and gold_labels is (batch_size,);
            # the ith entry of gold_labels is the class index (0..num_classes - 1) for the
            # ith row. For each row, check whether the class pointed to by the gold label
            # is among the maximum-scored classes.
            correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float()
            tie_counts = max_predictions_mask.sum(-1)
            correct /= tie_counts.float()
            correct.unsqueeze_(-1)
        if mask is not None:
            correct *= mask.view(-1, 1).float()
            self.total_count += mask.sum().item()
        else:
            self.total_count += gold_labels.numel()
        self.correct_count += correct.sum().item()
    def get_metric(self, reset: bool = False):
        """
        Returns
        -------
        The accumulated accuracy.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self):
        self.correct_count = 0.0
        self.total_count = 0.0
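
# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal example of how CategoricalAccuracy might be driven, using made-up
# logits: top-1 misses the second row, while top-2 covers both labels.
def _demo_categorical_accuracy():
    logits = torch.tensor([[0.1, 0.6, 0.3],
                           [0.5, 0.2, 0.3]])
    labels = torch.tensor([1, 2])

    top1 = CategoricalAccuracy(top_k=1)
    top1(logits, labels)
    print(top1.get_metric(reset=True))  # row 0 correct, row 1 not -> {'accuracy': 0.5}

    top2 = CategoricalAccuracy(top_k=2)
    top2(logits, labels)
    print(top2.get_metric(reset=True))  # both labels in the top 2 -> {'accuracy': 1.0}
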
def main():
    # Smoke test for torch_pesq: random noise stands in for real speech, so
    # the printed MOS/loss values only confirm that the API runs end to end.
    pesq = PesqLoss(0.5,
                    sample_rate=8000,
                    )
    reference = torch.randn(1, 44100)
    degraded = torch.randn(1, 44100)
    mos = pesq.mos(reference, degraded)
    loss = pesq(reference, degraded)
    print(mos, loss)
    return


if __name__ == '__main__':
    main()