Spaces:
Sleeping
Sleeping
File size: 376 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch
from torch import Tensor
from torchmetrics import Metric, Accuracy
class AccuracyMine(Accuracy):
    """Accuracy metric that also accepts soft (e.g. Mixup) targets.

    Soft targets, such as those produced by Mixup, are floating-point label
    distributions rather than hard class labels. This wrapper collapses such
    targets to hard class indices with ``argmax`` over the last dimension,
    then delegates to ``torchmetrics.Accuracy.update``. Integer targets are
    forwarded unchanged.

    NOTE(review): in newer torchmetrics releases ``Accuracy`` is a task
    wrapper whose ``__new__`` returns a task-specific metric instance. In
    that case this ``update`` override would never be called. Confirm the
    pinned torchmetrics version.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update metric state, converting soft targets to class indices.

        Args:
            preds: Model predictions, passed through unchanged to
                ``Accuracy.update``.
            target: Either integer class labels, or floating-point label
                distributions whose last dimension indexes classes.
        """
        # Only reduce targets that actually carry a class dimension. A 1-D
        # float tensor has no class axis: argmax over it would collapse the
        # entire batch into a single scalar index.
        if target.is_floating_point() and target.ndim > 1:
            target = target.argmax(dim=-1)
        super().update(preds, target)
|