Spaces:
Runtime error
Runtime error
File size: 7,668 Bytes
b34d1d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import mmcv
import torch
from mmengine.dist import broadcast_object_list, collect_results, is_main_process
from typing import Dict, Optional, Sequence
from mmengine.evaluator import BaseMetric
from mmdet.registry import METRICS
from mmengine.evaluator.metric import _to_cpu
from mmengine.visualization import Visualizer
@METRICS.register_module()
class InsClsIoUMetric(BaseMetric):
    """Per-instance classification score and mask IoU metric.

    For every ground-truth instance of every processed sample the metric
    records:

    * a classification score (only when ``with_score`` is true): ``100`` if
      the predicted label equals the ground-truth label, ``0`` otherwise;
    * the IoU between the predicted and the ground-truth binary mask, or
      ``0`` when the sample carries no predicted masks.

    Predictions are compared element-wise with the ground truth, so they are
    assumed to be aligned index-for-index with the ground-truth instances.

    Args:
        collect_device: Device name forwarded to ``BaseMetric`` and used when
            gathering results across processes.
        prefix: Forwarded to ``BaseMetric``.
            NOTE(review): ``evaluate`` below does not prepend it to the
            returned keys.
        base_classes: Class indices forming the "base" group. When given,
            ``novel_classes`` must be given too and per-group numbers are
            reported in addition to the overall ones.
        novel_classes: Class indices forming the "novel" group.
        with_score: Whether to compute the classification score.
        output_failure: Whether to write a visualisation of every
            misclassified instance into the ``failure_lvis`` directory.
            Needs ``with_score`` and predicted masks to have any effect.
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 base_classes=None,
                 novel_classes=None,
                 with_score=True,
                 output_failure=False,
                 ) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        # Per-sample accumulators; each entry is a 1-D CPU tensor holding one
        # value per ground-truth instance of that sample.
        self.scores = []
        self.iou_list = []
        self.base_scores = []
        self.novel_scores = []
        self.base_iou_list = []
        self.novel_iou_list = []

        self.with_score = with_score
        self.output_failure = output_failure

        if base_classes is not None:
            assert novel_classes is not None
            num_classes = max(max(base_classes) + 1, max(novel_classes) + 1)
            # Lookup table: 0 = unknown class, 1 = base class, 2 = novel class.
            self.base_novel_indicator = torch.zeros(
                (num_classes,), dtype=torch.long)
            for clss in base_classes:
                self.base_novel_indicator[clss] = 1
            for clss in novel_classes:
                self.base_novel_indicator[clss] = 2
        else:
            self.base_novel_indicator = None

    def get_iou(self, gt_masks, pred_masks):
        """Return the IoU of each ground-truth/predicted mask pair.

        Args:
            gt_masks: Boolean tensor of shape ``(n, h, w)``.
            pred_masks: Boolean tensor combined element-wise with
                ``gt_masks`` (so it must be broadcast-compatible with it).

        Returns:
            Tensor of ``n`` IoU values. A small epsilon in the denominator
            makes the IoU of two empty masks ``0`` instead of ``nan``.
        """
        n, h, w = gt_masks.shape
        intersection = (gt_masks & pred_masks).reshape(n, h * w).sum(dim=-1)
        union = (gt_masks | pred_masks).reshape(n, h * w).sum(dim=-1)
        return intersection / (union + 1.e-8)

    def _dump_failures(self, data_sample, score, pred_masks,
                       gt_labels, pred_labels) -> None:
        """Write one overlay image per misclassified instance.

        The predicted mask of every instance whose score is ``0`` is drawn on
        the source image and saved as
        ``failure_lvis/<image stem>_<idx>_<gt label>_<pred label>.jpg``.
        """
        img_path = data_sample['img_path']
        img_stem = os.path.basename(img_path).split('.')[0]
        for idx, _score in enumerate(score.cpu().tolist()):
            if _score != 0.:
                continue
            vis = Visualizer()
            # mmcv reads BGR; the visualizer is fed RGB here.
            vis.set_image(mmcv.bgr2rgb(mmcv.imread(img_path)))
            vis.draw_binary_masks(
                pred_masks[idx], alphas=.85, colors=[(250, 177, 135)])
            vis_res = vis.get_image()
            if vis_res is None:
                continue
            file_name = (f'{img_stem}_{idx}_{int(gt_labels[idx])}'
                         f'_{int(pred_labels[idx])}.jpg')
            mmcv.imwrite(
                mmcv.rgb2bgr(vis_res), os.path.join('failure_lvis', file_name))

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Accumulate scores and IoUs for one batch of samples.

        Args:
            data_batch: Unused; part of the ``BaseMetric`` interface.
            data_samples: One dict per sample. Read keys:
                ``['gt_instances']['labels']``, ``['pred_instances']['labels']``
                (when ``with_score``), ``['pred_instances']['masks']``
                (optional), ``['gt_instances']['masks']`` (when predicted
                masks exist) and ``['img_path']`` (when dumping failures).
        """
        for data_sample in data_samples:
            gt_labels = data_sample['gt_instances']['labels']
            num_gts = len(gt_labels)
            score = None
            if num_gts == 0:
                score = gt_labels.new_zeros((0,), dtype=torch.float)
                ious = gt_labels.new_zeros((0,), dtype=torch.float)
            else:
                pred_instances = data_sample['pred_instances']
                if self.with_score:
                    if self.base_novel_indicator is not None:
                        # Every GT label must belong to a known group.
                        assert (self.base_novel_indicator[gt_labels.cpu()]
                                > 0).all()
                    pred_labels = pred_instances['labels']
                    score = (pred_labels == gt_labels).to(
                        dtype=torch.float) * 100

                if 'masks' in pred_instances:
                    pred_masks = pred_instances['masks']
                    # Failures are defined by the score, so the dump is only
                    # possible when scores were computed.
                    if self.output_failure and score is not None:
                        self._dump_failures(
                            data_sample, score, pred_masks,
                            gt_labels, pred_labels)
                    gt_masks = data_sample['gt_instances']['masks']
                    # NOTE(review): relies on the GT mask container exposing
                    # ``to_tensor(dtype=..., device=...)``.
                    gt_masks = gt_masks.to_tensor(
                        dtype=torch.bool, device=pred_masks.device)
                    ious = self.get_iou(gt_masks, pred_masks)
                else:
                    # No predicted masks: one float zero per GT instance.
                    # (``new_tensor([0.])`` would inherit the integer dtype of
                    # the labels and have the wrong length for the per-group
                    # boolean indexing below.)
                    ious = gt_labels.new_zeros((num_gts,), dtype=torch.float)

            ious = ious.to(device='cpu')
            self.iou_list.append(ious)

            if self.base_novel_indicator is not None:
                # 1 = base, 2 = novel, looked up once per sample.
                group = self.base_novel_indicator[gt_labels.cpu()]
                is_base = group == 1
                is_novel = group == 2
                self.base_iou_list.append(ious[is_base])
                self.novel_iou_list.append(ious[is_novel])

            if self.with_score:
                score = score.to(device='cpu')
                self.scores.append(score)
                if self.base_novel_indicator is not None:
                    self.base_scores.append(score[is_base])
                    self.novel_scores.append(score[is_novel])

    def compute_metrics(self, scores, ious,
                        base_scores, base_ious,
                        novel_scores, novel_ious) -> Dict[str, float]:
        """Average the concatenated per-instance values.

        Args:
            scores: All classification scores, or ``None`` without scoring.
            ious: All IoU values.
            base_scores: Scores of base-class instances, or ``None``.
            base_ious: IoUs of base-class instances, or ``None``.
            novel_scores: Scores of novel-class instances, or ``None``.
            novel_ious: IoUs of novel-class instances, or ``None``.

        Returns:
            Dict with ``miou`` and, depending on the configuration,
            ``base_iou``, ``novel_iou``, ``score``, ``base_score`` and
            ``novel_score``. The mean of an empty group is ``nan``.
        """
        results = dict()
        results['miou'] = ious.mean().item()
        if self.base_novel_indicator is not None:
            results['base_iou'] = base_ious.mean().item()
            results['novel_iou'] = novel_ious.mean().item()
        if self.with_score:
            results['score'] = scores.mean().item()
            if base_scores is not None:
                results['base_score'] = base_scores.mean().item()
                results['novel_score'] = novel_scores.mean().item()
        return results

    def evaluate(self, size: int) -> dict:
        """Gather the accumulated values from all processes and reduce them.

        Args:
            size: Forwarded to ``collect_results`` as the expected number of
                collected entries (one entry per processed sample).

        Returns:
            The dict produced by ``compute_metrics`` on the main process,
            broadcast to every process.
        """
        has_groups = self.base_novel_indicator is not None

        _ious = collect_results(self.iou_list, size, self.collect_device)
        if has_groups:
            _base_ious = collect_results(
                self.base_iou_list, size, self.collect_device)
            _novel_ious = collect_results(
                self.novel_iou_list, size, self.collect_device)
        if self.with_score:
            _scores = collect_results(self.scores, size, self.collect_device)
            if has_groups:
                _base_scores = collect_results(
                    self.base_scores, size, self.collect_device)
                _novel_scores = collect_results(
                    self.novel_scores, size, self.collect_device)

        # Start the next evaluation round from scratch; otherwise repeated
        # evaluations (e.g. validation after every epoch) would keep mixing
        # in samples from earlier rounds. Fresh lists are bound instead of
        # clearing in place so the collected results above stay intact.
        self.scores = []
        self.iou_list = []
        self.base_scores = []
        self.novel_scores = []
        self.base_iou_list = []
        self.novel_iou_list = []

        if is_main_process():
            base_ious = novel_ious = None
            scores = base_scores = novel_scores = None
            if has_groups:
                base_ious = torch.cat(_base_ious)
                novel_ious = torch.cat(_novel_ious)
            if self.with_score:
                scores = _to_cpu(torch.cat(_scores))
                if has_groups:
                    base_scores = torch.cat(_base_scores)
                    novel_scores = torch.cat(_novel_scores)
            ious = _to_cpu(torch.cat(_ious))
            _metrics = self.compute_metrics(
                scores, ious,
                base_scores, base_ious,
                novel_scores, novel_ious
            )
            metrics = [_metrics]
        else:
            metrics = [None]  # type: ignore

        broadcast_object_list(metrics)
        return metrics[0]
|