Last commit not found
# Copyright (c) Facebook, Inc. and its affiliates.
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------
import torch | |
import logging | |
from detectron2.evaluation.evaluator import DatasetEvaluator | |
from utilities.misc import AverageMeter | |
from utilities.distributed import get_world_size | |
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor indexed as (sample, class), or a list of
            such tensors, in which case only the last element is scored.
        target: tensor holding one ground-truth class index per sample.
        topk: iterable of k values to evaluate. Any k larger than the
            number of classes is effectively clamped to that number.

    Returns:
        A list of Python floats, one per entry of ``topk``, each in [0, 100].
    """
    scores = output[-1] if isinstance(output, list) else output

    num_classes = scores.size()[1]
    # Never request more candidates than there are classes.
    deepest_k = min(max(topk), num_classes)
    num_samples = target.size(0)

    # After the transpose, row r holds every sample's rank-r predicted class.
    _, ranked = scores.topk(deepest_k, dim=1, largest=True, sorted=True)
    ranked = ranked.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    percent_per_hit = 100.0 / num_samples
    # A sample counts for k when its label appears in any of the first k rows.
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit).item()
        for k in topk
    ]
class ClassificationEvaluator(DatasetEvaluator):
    """Accumulate top-1 and top-5 classification accuracy over a dataset.

    Batches are scored in ``process`` and folded into two ``AverageMeter``
    instances; ``evaluate`` turns the accumulated sums into final scores,
    summing across processes first when more than one is in use.
    """

    def __init__(self, *args):
        # Positional arguments are accepted but not used.
        self.top1 = AverageMeter()
        self.top5 = AverageMeter()
        self._logger = logging.getLogger(__name__)

    def reset(self):
        """Reset both running meters."""
        self.top1.reset()
        self.top5.reset()

    def process(self, inputs, outputs):
        """Score one batch and add it to the running meters.

        Args:
            inputs: iterable of per-sample dicts, each with a ``'class_id'``
                entry (the ground-truth label).
            outputs: iterable of per-sample dicts, each with a
                ``'pred_class'`` tensor; these are stacked into one batch.
        """
        logits = torch.stack([o['pred_class'] for o in outputs])
        y = torch.tensor([t['class_id'] for t in inputs], device=logits.device)
        prec1, prec5 = accuracy(logits, y, (1, 5))
        # Presumably AverageMeter.update(value, n) weights the value by n
        # samples -- confirm against utilities.misc.
        self.top1.update(prec1, y.size(0))
        self.top5.update(prec5, y.size(0))

    def evaluate(self):
        """Compute the final top-1 / top-5 scores.

        Returns:
            ``{'class': {'top1': ..., 'top5': ...}}``, or an empty dict when
            no samples were accumulated.
        """
        if get_world_size() > 1:
            # Use the current CUDA device when there is one; otherwise fall
            # back to CPU so CPU-only distributed runs do not crash here.
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            # float64 avoids the rounding that the default float32 applies
            # to large sample counts and percentage sums.
            totals = torch.tensor(
                [self.top1.sum, self.top5.sum, self.top1.count],
                dtype=torch.float64,
                device=device,
            )
            torch.distributed.all_reduce(
                totals, torch.distributed.ReduceOp.SUM
            )
            top1_sum, top5_sum, count = totals.tolist()
        else:
            top1_sum = self.top1.sum
            top5_sum = self.top5.sum
            count = self.top1.count

        # Checked after the reduction so every rank takes the same branch.
        if count == 0:
            self._logger.warning(
                "[ClassificationEvaluator] Did not receive any predictions."
            )
            return {}

        results = {}
        scores = {
            'top1': top1_sum / count,
            "top5": top5_sum / count
        }
        results['class'] = scores
        self._logger.info(results)
        return results