File size: 1,228 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
import torch.nn.functional as F
from uniperceiver.config import configurable
from .build import LOSSES_REGISTRY

@LOSSES_REGISTRY.register()
class Accuracy(nn.Module):
    """Top-1 accuracy metric, registered in ``LOSSES_REGISTRY``.

    The constructor takes no arguments. Calling the module on an outputs
    dictionary yields one accuracy value per (logits, targets, loss name)
    triple found in that dictionary.
    """

    @configurable
    def __init__(self):
        super().__init__()

    @classmethod
    def from_config(cls, cfg):
        """Return an empty dict: no constructor arguments are read from ``cfg``."""
        return {}

    def Forward(self, logits, targets):
        """Compute the fraction of argmax predictions that equal the targets.

        Args:
            logits: Tensor whose last dimension is scored with argmax; all
                leading dimensions are flattened into one.
            targets: Tensor of labels; flattened to one dimension before the
                comparison.

        Returns:
            A scalar float tensor: the mean of the element-wise matches.
        """
        flat_logits = logits.view(-1, logits.shape[-1])
        predictions = flat_logits.argmax(dim=-1)
        labels = targets.view(-1)
        hits = (predictions == labels).float()
        return hits.mean()

    def forward(self, outputs_dict):
        """Compute an accuracy entry for every triple in ``outputs_dict``.

        Args:
            outputs_dict: Mapping with the keys ``'logits'``, ``'targets'``
                and ``'loss_names'``; the three values are iterated in
                lockstep.

        Returns:
            Dict mapping ``'Accuracy'`` (or ``'Accuracy (<name>)'`` when the
            loss name is non-empty) to the accuracy tensor of that triple.
        """
        metrics = {}
        triples = zip(
            outputs_dict['logits'],
            outputs_dict['targets'],
            outputs_dict['loss_names'],
        )
        for logit, target, tag in triples:
            # Targets shaped like the logits (the mixup case) are reduced to
            # class indices before being compared with the predictions.
            if logit.shape == target.shape:
                target = target.argmax(dim=-1)

            suffix = f' ({tag})' if len(tag) > 0 else ''
            metrics['Accuracy' + suffix] = self.Forward(logit, target)

        return metrics