Spaces:
Sleeping
Sleeping
from typing import Sequence

import torch
from ignite.metrics import Accuracy, Loss
from ignite.metrics.metric import reinit__is_reduced
class TopKAccuracy(Accuracy):
    """Ordered top-k match rate, built on ignite's ``Accuracy`` accumulators.

    A sample counts as correct only when the ``k`` highest-scoring class
    indices of ``y_pred``, taken in descending score order, equal the
    corresponding row of ``y_attack`` position by position. ``k`` is inferred
    from the width of ``y_attack``. ``compute()`` and ``reset()`` are not
    overridden here and are inherited from ignite's ``Accuracy``.

    NOTE(review): ``y_attack`` presumably holds the ordered target classes of a
    top-k adversarial attack — confirm against the caller.
    """

    # Same decorator as ignite's Accuracy.update: it clears the metric's
    # "already reduced" flag after each update so a distributed compute()
    # performs a fresh all-reduce over the newly accumulated counts.
    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor], **kwargs) -> None:
        """Accumulate one batch of ordered top-k matches.

        Args:
            output: ``(y_pred, y_attack)``. ``y_pred`` is an ``[N, C]`` tensor
                of per-class scores; ``y_attack`` is an ``[N, k]`` tensor of
                class indices, with ``1 <= k <= C``.
            **kwargs: Ignored; accepted for backward compatibility.

        Raises:
            ValueError: If ``y_pred``/``y_attack`` are not 2-D, their batch
                sizes differ, or ``k`` is not in ``[1, C]``.
        """
        y_pred, y_attack = output[0].detach(), output[1].detach()

        if y_pred.ndim != 2 or y_attack.ndim != 2:
            raise ValueError(
                "TopKAccuracy expects y_pred of shape [N, C] and y_attack of "
                f"shape [N, k], got {tuple(y_pred.shape)} and "
                f"{tuple(y_attack.shape)}."
            )
        if y_pred.shape[0] != y_attack.shape[0]:
            raise ValueError(
                "y_pred and y_attack must have the same batch size, got "
                f"{y_pred.shape[0]} and {y_attack.shape[0]}."
            )
        k = y_attack.shape[-1]
        num_classes = y_pred.shape[-1]
        # k == 0 would make .all() vacuously True; k > C cannot be matched.
        if not 1 <= k <= num_classes:
            raise ValueError(
                f"y_attack width k={k} must satisfy 1 <= k <= C={num_classes}."
            )

        # Only the k best classes (in descending score order) are needed, so
        # topk avoids fully sorting all C scores. Ties between equal scores are
        # resolved in an unspecified order, as with the previous argsort.
        top_k_indices = torch.topk(y_pred, k, dim=-1, largest=True, sorted=True).indices  # [N, k]
        # Order-sensitive comparison: every position must match.
        correct = (top_k_indices == y_attack).all(dim=-1)  # [N]

        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]