Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
import os | |
import numpy as np | |
import torch | |
from torchvision.ops import box_iou | |
from detectron2.structures import BoxMode | |
from detectron2.data import MetadataCatalog | |
from detectron2.utils.comm import all_gather, gather, is_main_process, synchronize | |
from detectron2.evaluation.evaluator import DatasetEvaluator | |
class InteractiveEvaluator(DatasetEvaluator):
    """
    Evaluate point-interactive segmentation with IoU-based metrics.

    Each model output passed to :meth:`process` must carry a ``'mask_iou'``
    entry. It is treated as a 1-D tensor holding the mask IoU after each
    successive click; presumably there is one entry per click up to
    ``max_clicks`` (TODO confirm against the model).

    From these values the evaluator reports:

    * ``noc@<thr>``: mean Number Of Clicks needed to first reach an IoU of
      at least ``thr``. A sample that never reaches ``thr`` counts as
      ``max_clicks``.
    * ``miou@iter<iou_iter>``: mean IoU after click number ``iou_iter``.

    On the main process it also saves a mean-IoU-vs-clicks plot under
    ``<output_dir>/iou_by_clicks/`` when ``output_dir`` is not ``None``.
    """

    def __init__(
        self,
        dataset_name,
        output_dir,
        max_clicks=20,
        iou_iter=1,
        compute_box=False,
        distributed=True,
    ):
        """
        Args:
            dataset_name: name of the dataset; its ``MetadataCatalog`` entry
                is looked up and stored.
            output_dir: directory under which the IoU curve is written.
                ``None`` disables the plot.
            max_clicks: NoC charged to a sample that never reaches a
                threshold.
            iou_iter: 1-based click index whose IoU is averaged into
                ``miou@iter<iou_iter>``.
            compute_box: accepted for interface compatibility; not used by
                this evaluator.
            distributed: if True, aggregate results across processes with
                ``all_gather``; only the main process reports.
        """
        self._logger = logging.getLogger(__name__)
        self._dataset_name = dataset_name
        self._distributed = distributed
        self._cpu_device = torch.device("cpu")
        self._output_dir = output_dir
        self._compute_box = compute_box
        self.max_clicks = max_clicks
        self.iou_iter = iou_iter
        self._metadata = MetadataCatalog.get(dataset_name)

    def reset(self):
        """Clear accumulated per-sample IoU curves and the sample count."""
        self.iou_list = []
        self.num_samples = 0
        # IoU thresholds for which NoC is reported.
        self.all_ious = [0.5, 0.8, 0.85, 0.9]

    def process(self, inputs, outputs):
        """Collect the per-click IoU tensor ``o['mask_iou']`` of every output."""
        self.iou_list.extend(o["mask_iou"] for o in outputs)
        self.num_samples += len(outputs)

    def compute_noc(self):
        """
        Compute the mean NoC per threshold and the mean IoU at ``iou_iter``.

        In distributed mode every process must call this method, because it
        performs collective ``all_gather`` calls.

        Returns:
            dict: ``{str(thr): mean NoC}`` for each ``thr`` in
            ``self.all_ious``, plus ``'iou_max_iter'`` (mean IoU after click
            ``iou_iter``). An empty dict is returned on non-main processes
            and when no samples were processed at all.
        """

        def _get_noc(iou_arr, iou_thr):
            # 1-based index of the first click whose IoU reaches iou_thr.
            # Samples that never reach it are charged max_clicks.
            reached = iou_arr >= iou_thr
            if not reached.any():
                return self.max_clicks
            # argmax returns the index of the first maximum, i.e. the first True.
            return int(reached.int().argmax()) + 1

        noc_sum = {
            str(thr): float(sum(_get_noc(iou_arr, thr) for iou_arr in self.iou_list))
            for thr in self.all_ious
        }

        # torch.stack raises on an empty list, which happens when this
        # process received no samples; contribute 0 to the sum instead.
        if self.iou_list:
            iou_at_iter_sum = (
                torch.stack(self.iou_list)[:, self.iou_iter - 1].sum().cpu()
            )
        else:
            iou_at_iter_sum = torch.tensor(0.0)

        if self._distributed:
            num_samples = sum(all_gather(self.num_samples))
            noc_sum_gather = all_gather(noc_sum)
            iou_sum_gather = all_gather(iou_at_iter_sum)
            if not is_main_process():
                return {}
            noc_sum = {
                key: sum(per_rank[key] for per_rank in noc_sum_gather)
                for key in noc_sum
            }
            iou_total = sum(x.item() for x in iou_sum_gather)
        else:
            # Single-process path: the local totals are the global totals.
            num_samples = self.num_samples
            iou_total = iou_at_iter_sum.item()

        if num_samples == 0:
            self._logger.warning("[InteractiveEvaluator] Did not receive any samples.")
            return {}

        pred_noc = {key: value / num_samples for key, value in noc_sum.items()}
        pred_noc["iou_max_iter"] = iou_total / num_samples
        return pred_noc

    def _draw_iou_curve(self, iou_list, save_dir):
        """
        Save a plot of mean IoU versus number of clicks.

        The plot is written to ``<save_dir>/iou_by_clicks/<unix-time>.png``.
        Only the IoU tensors held by this process are plotted; they are not
        gathered across processes.
        """
        # Imported lazily so matplotlib is only required when plotting.
        import time

        import matplotlib.pyplot as plt

        mean_iou = torch.stack(iou_list, dim=0).mean(dim=0).cpu().numpy()

        output_dir = os.path.join(save_dir, "iou_by_clicks")
        os.makedirs(output_dir, exist_ok=True)

        fig = plt.figure()
        try:
            # The x-axis follows the data length, so a curve whose length
            # differs from max_clicks cannot make plt.plot raise.
            plt.plot(range(1, len(mean_iou) + 1), mean_iou)
            plt.xlabel("Number of clicks")
            plt.ylabel("IoU")
            # File name is the current Unix time in whole seconds.
            filename = "{}.png".format(int(time.time()))
            fig.savefig(os.path.join(output_dir, filename))
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    def evaluate(self):
        """
        Aggregate results and, on the main process, plot and report them.

        Returns:
            ``{'interactive': {'noc@<thr>': ..., 'miou@iter<k>': ...}}`` on
            the main process. ``{}`` if no samples were seen. ``None`` on
            non-main processes in distributed mode.
        """
        pred_noc = self.compute_noc()
        if self._distributed and not is_main_process():
            return None
        if not pred_noc:
            return {}

        if self._output_dir is not None and self.iou_list:
            self._draw_iou_curve(self.iou_list, self._output_dir)

        results = {"noc@{}".format(thr): pred_noc[str(thr)] for thr in self.all_ious}
        results["miou@iter{}".format(self.iou_iter)] = pred_noc["iou_max_iter"]
        self._logger.info(results)
        return {"interactive": results}