Spaces:
Runtime error
Runtime error
File size: 7,668 Bytes
b34d1d6 |
|
import os
import mmcv
import torch
from mmengine.dist import broadcast_object_list, collect_results, is_main_process
from typing import Dict, Optional, Sequence
from mmengine.evaluator import BaseMetric
from mmdet.registry import METRICS
from mmengine.evaluator.metric import _to_cpu
from mmengine.visualization import Visualizer
@METRICS.register_module()
class InsClsIoUMetric(BaseMetric):
    """Per-instance classification score and mask IoU metric.

    For every ground-truth instance of every sample this metric records:

    * a classification score: 100 if the predicted label at the same index
      equals the ground-truth label, else 0 (only when ``with_score``);
    * a mask IoU between the ground-truth mask and the predicted mask at the
      same index, or 0 when the prediction carries no ``masks``.

    Predictions are paired with ground truth purely by position.
    NOTE(review): ``pred_instances`` is assumed to be aligned one-to-one with
    ``gt_instances``; this is not checked here.

    When ``base_classes``/``novel_classes`` are given, every statistic is also
    averaged separately over base-class and novel-class instances.

    Args:
        collect_device: Device used to gather results across ranks.
        prefix: Optional prefix; metric names become ``'<prefix>/<name>'``.
        base_classes: Iterable of base class ids (requires ``novel_classes``).
        novel_classes: Iterable of novel class ids. An id listed in both
            iterables is treated as novel.
        with_score: Whether to compute the classification score.
        output_failure: Whether to save a visualization of the predicted mask
            of every misclassified instance (requires ``with_score``).
        failure_dir: Directory the failure visualizations are written to.
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 base_classes=None,
                 novel_classes=None,
                 with_score=True,
                 output_failure=False,
                 failure_dir: str = 'failure_lvis',
                 ) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.with_score = with_score
        self.output_failure = output_failure
        self.failure_dir = failure_dir
        if base_classes is not None:
            assert novel_classes is not None, \
                'novel_classes must be given together with base_classes'
            num_classes = max(max(base_classes) + 1, max(novel_classes) + 1)
            # Lookup table indexed by class id: 0 = not listed, 1 = base,
            # 2 = novel. Novel is written last so it wins on overlap.
            self.base_novel_indicator = torch.zeros((num_classes,), dtype=torch.long)
            for clss in base_classes:
                self.base_novel_indicator[clss] = 1
            for clss in novel_classes:
                self.base_novel_indicator[clss] = 2
        else:
            self.base_novel_indicator = None
        self._reset()

    def _reset(self) -> None:
        """Drop all per-sample results accumulated by ``process``."""
        self.scores = []
        self.iou_list = []
        self.base_scores = []
        self.novel_scores = []
        self.base_iou_list = []
        self.novel_iou_list = []

    def get_iou(self, gt_masks, pred_masks):
        """Compute the IoU of each ground-truth mask with its paired prediction.

        Args:
            gt_masks: Tensor of shape (N, H, W); ``process`` passes a bool tensor.
            pred_masks: Tensor broadcast-compatible with ``gt_masks`` under
                ``&``/``|`` (assumed boolean (N, H, W) — TODO confirm).

        Returns:
            Float tensor of shape (N,). A pair with an empty union gets IoU 0.
        """
        n, h, w = gt_masks.shape
        intersection = (gt_masks & pred_masks).reshape(n, h * w).sum(dim=-1)
        union = (gt_masks | pred_masks).reshape(n, h * w).sum(dim=-1)
        # The epsilon turns 0 / 0 (empty union) into 0 instead of NaN.
        return intersection / (union + 1.e-8)

    def _save_failure_visualizations(self, data_sample, score, gt_labels,
                                     pred_labels, pred_masks) -> None:
        """Save the predicted mask of every instance whose score is 0.

        Each image is written to ``self.failure_dir`` as
        ``<image stem>_<instance idx>_<gt label>_<pred label>.jpg``.
        """
        img_path = data_sample['img_path']
        img_stem = os.path.splitext(os.path.basename(img_path))[0]
        vis = None
        rgb_img = None
        for idx, _score in enumerate(score.tolist()):
            if _score != 0.:
                continue
            if vis is None:
                # Load the image and build the visualizer once per sample,
                # only if there is at least one failure to draw.
                vis = Visualizer()
                rgb_img = mmcv.bgr2rgb(mmcv.imread(img_path))
            # set_image resets the canvas, so earlier drawings do not leak in.
            vis.set_image(rgb_img)
            vis.draw_binary_masks(pred_masks[idx], alphas=.85,
                                  colors=[(250, 177, 135)])
            vis_res = vis.get_image()
            if vis_res is None:
                continue
            file_name = (f'{img_stem}_{idx}_{int(gt_labels[idx])}'
                         f'_{int(pred_labels[idx])}.jpg')
            mmcv.imwrite(mmcv.rgb2bgr(vis_res),
                         os.path.join(self.failure_dir, file_name))

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Accumulate per-instance scores and IoUs for a batch of samples.

        Args:
            data_batch: The model input batch (unused).
            data_samples: Each item is read as a mapping with
                ``'gt_instances'`` (``labels``; ``masks`` when the prediction
                has masks), ``'pred_instances'`` (``labels`` when
                ``with_score``; optional ``masks``) and, when failures are
                dumped, ``'img_path'``.
        """
        for data_sample in data_samples:
            gt_labels = data_sample['gt_instances']['labels']
            score = None
            if len(gt_labels) == 0:
                score = gt_labels.new_zeros((0,), dtype=torch.float)
                ious = gt_labels.new_zeros((0,), dtype=torch.float)
            else:
                pred_instances = data_sample['pred_instances']
                if self.with_score:
                    if self.base_novel_indicator is not None:
                        assert (self.base_novel_indicator[gt_labels.cpu()] > 0).all(), \
                            'ground-truth label is neither a base nor a novel class'
                    pred_labels = pred_instances['labels']
                    # Index-wise match: 100 for a correct label, else 0.
                    score = (pred_labels == gt_labels).to(dtype=torch.float) * 100
                if 'masks' in pred_instances:
                    pred_masks = pred_instances['masks']
                    if self.output_failure and self.with_score:
                        self._save_failure_visualizations(
                            data_sample, score, gt_labels, pred_labels, pred_masks)
                    gt_masks = data_sample['gt_instances']['masks']
                    gt_masks = gt_masks.to_tensor(dtype=torch.bool, device=pred_masks.device)
                    ious = self.get_iou(gt_masks, pred_masks)
                else:
                    # No predicted masks: IoU 0 for every instance. Keep one
                    # float entry per gt instance so the base/novel split
                    # below stays aligned and mean() is well defined.
                    ious = gt_labels.new_zeros(gt_labels.shape, dtype=torch.float)

            # Per-instance category (1 = base, 2 = novel), or None if no split.
            split = None
            if self.base_novel_indicator is not None:
                split = self.base_novel_indicator[gt_labels.cpu()]

            ious = ious.to(device='cpu')
            self.iou_list.append(ious)
            if split is not None:
                self.base_iou_list.append(ious[split == 1])
                self.novel_iou_list.append(ious[split == 2])
            if self.with_score:
                score = score.to(device='cpu')
                self.scores.append(score)
                if split is not None:
                    self.base_scores.append(score[split == 1])
                    self.novel_scores.append(score[split == 2])

    def compute_metrics(self, scores, ious,
                        base_scores, base_ious,
                        novel_scores, novel_ious) -> Dict[str, float]:
        """Average the gathered per-instance statistics.

        Args:
            scores, base_scores, novel_scores: 1-D score tensors, or None
                when scoring / the base-novel split is disabled.
            ious, base_ious, novel_ious: 1-D IoU tensors (split ones may be
                None when no base/novel split is configured).

        Returns:
            ``{'miou': ...}`` plus ``base_iou``/``novel_iou`` when a split is
            configured, and ``score`` (plus ``base_score``/``novel_score``)
            when ``with_score``. The mean of an empty tensor is NaN.
        """
        results = {'miou': ious.mean().item()}
        if self.base_novel_indicator is not None:
            results['base_iou'] = base_ious.mean().item()
            results['novel_iou'] = novel_ious.mean().item()
        if self.with_score:
            results['score'] = scores.mean().item()
            if base_scores is not None:
                results['base_score'] = base_scores.mean().item()
                results['novel_score'] = novel_scores.mean().item()
        return results

    def evaluate(self, size: int) -> dict:
        """Gather results from all ranks, compute metrics and reset state.

        Args:
            size: Length of the whole evaluation dataset; gathered results
                are truncated to this many samples by ``collect_results``.

        Returns:
            The metric dict (broadcast so every rank returns the same one);
            keys are ``'<prefix>/<name>'`` when a prefix is set.
        """
        has_split = self.base_novel_indicator is not None

        def _gather(part):
            return collect_results(part, size, self.collect_device)

        # Collective calls: the order below must be identical on all ranks.
        _base_ious = _novel_ious = None
        _scores = _base_scores = _novel_scores = None
        _ious = _gather(self.iou_list)
        if has_split:
            _base_ious = _gather(self.base_iou_list)
            _novel_ious = _gather(self.novel_iou_list)
        if self.with_score:
            _scores = _gather(self.scores)
            if has_split:
                _base_scores = _gather(self.base_scores)
                _novel_scores = _gather(self.novel_scores)

        if is_main_process():
            ious = _to_cpu(torch.cat(_ious))
            base_ious = torch.cat(_base_ious) if has_split else None
            novel_ious = torch.cat(_novel_ious) if has_split else None
            scores = base_scores = novel_scores = None
            if self.with_score:
                scores = _to_cpu(torch.cat(_scores))
                if has_split:
                    base_scores = torch.cat(_base_scores)
                    novel_scores = torch.cat(_novel_scores)
            _metrics = self.compute_metrics(
                scores, ious,
                base_scores, base_ious,
                novel_scores, novel_ious
            )
            if self.prefix:
                _metrics = {'/'.join((self.prefix, k)): v
                            for k, v in _metrics.items()}
            metrics = [_metrics]
        else:
            metrics = [None]  # type: ignore
        broadcast_object_list(metrics)
        # Start the next evaluation round (e.g. next validation epoch) from
        # scratch, as BaseMetric.evaluate does with self.results.clear().
        self._reset()
        return metrics[0]
|