pengdadaaa's picture
Upload 741 files
786f6a6 verified
raw
history blame
1.8 kB
""" Real labels evaluator for ImageNet
Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
Based on Numpy example at https://github.com/google-research/reassessed-imagenet
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import json
import numpy as np
import pkgutil
class RealLabelsImagenet:
def __init__(self, filenames, real_json=None, topk=(1, 5)):
if real_json is not None:
with open(real_json) as real_labels:
real_labels = json.load(real_labels)
else:
real_labels = json.loads(
pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
self.real_labels = real_labels
self.filenames = filenames
assert len(self.filenames) == len(self.real_labels)
self.topk = topk
self.is_correct = {k: [] for k in topk}
self.sample_idx = 0
def add_result(self, output):
maxk = max(self.topk)
_, pred_batch = output.topk(maxk, 1, True, True)
pred_batch = pred_batch.cpu().numpy()
for pred in pred_batch:
filename = self.filenames[self.sample_idx]
filename = os.path.basename(filename)
if self.real_labels[filename]:
for k in self.topk:
self.is_correct[k].append(
any([p in self.real_labels[filename] for p in pred[:k]]))
self.sample_idx += 1
def get_accuracy(self, k=None):
if k is None:
return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
else:
return float(np.mean(self.is_correct[k])) * 100