Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from isegm.utils import misc | |
class TrainMetric:
    """Base class for metrics accumulated over a training epoch.

    Subclasses must override :meth:`update`, :meth:`get_epoch_value` and
    :meth:`reset_epoch_stats` (the base versions raise
    ``NotImplementedError``). :meth:`log_states` is an optional hook whose
    base implementation does nothing.

    Attributes:
        pred_outputs: Prediction output identifiers, stored verbatim.
        gt_outputs: Ground-truth output identifiers, stored verbatim.
    """

    def __init__(self, pred_outputs, gt_outputs):
        self.pred_outputs, self.gt_outputs = pred_outputs, gt_outputs

    def update(self, *args, **kwargs):
        """Consume one batch of predictions and targets; must be overridden."""
        raise NotImplementedError

    def get_epoch_value(self):
        """Return the value aggregated over the current epoch; must be overridden."""
        raise NotImplementedError

    def reset_epoch_stats(self):
        """Clear the per-epoch accumulators; must be overridden."""
        raise NotImplementedError

    def log_states(self, sw, tag_prefix, global_step):
        """Optional logging hook; the base implementation logs nothing."""
        return None

    def name(self):
        """Return the name of the concrete (sub)class of this metric."""
        class_name = type(self).__name__
        return class_name
class AdaptiveIoU(TrainMetric):
    """IoU training metric that adapts its own binarization threshold.

    On every :meth:`update` the prediction is binarized at the current
    threshold and at ``threshold - thresh_step`` / ``threshold + thresh_step``.
    The candidate with the highest batch-mean IoU is blended into the running
    threshold with an exponential moving average (``thresh_beta``). That best
    IoU is folded into a logged EMA (``iou_beta``) and accumulated for the
    per-epoch average returned by :meth:`get_epoch_value`.

    Args:
        init_thresh: Initial binarization threshold, applied to ``pred``
            after the sigmoid when ``from_logits`` is true.
        thresh_step: Offset of the two alternative thresholds probed per update.
        thresh_beta: EMA decay for the threshold (closer to 1 = slower drift).
        iou_beta: EMA decay for the logged IoU.
        ignore_label: Ground-truth value marking pixels excluded from the IoU.
        from_logits: If true, apply ``torch.sigmoid`` to ``pred`` first.
        pred_output: Stored as ``pred_outputs=(pred_output,)`` on the base
            class (presumably selects which model output is fed to
            ``update`` -- confirm against the trainer).
        gt_output: Stored as ``gt_outputs=(gt_output,)`` on the base class.
    """

    def __init__(
        self,
        init_thresh=0.4,
        thresh_step=0.025,
        thresh_beta=0.99,
        iou_beta=0.9,
        ignore_label=-1,
        from_logits=True,
        pred_output="instances",
        gt_output="instances",
    ):
        super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
        self._ignore_label = ignore_label
        self._from_logits = from_logits
        self._iou_thresh = init_thresh
        self._thresh_step = thresh_step
        self._thresh_beta = thresh_beta
        self._iou_beta = iou_beta
        self._ema_iou = 0.0
        self._epoch_iou_sum = 0.0
        self._epoch_batch_count = 0

    def update(self, pred, gt):
        """Score one batch and nudge the threshold toward the best candidate.

        Args:
            pred: Prediction tensor (logits if ``from_logits`` else
                probabilities), comparable element-wise with ``gt``.
            gt: Ground-truth tensor; values ``> 0.5`` are foreground and
                values equal to ``ignore_label`` are ignored. Dim 0 is
                presumably the batch axis (see ``_compute_iou``).

        A batch whose non-ignored ground truth is empty everywhere is skipped
        and leaves every statistic unchanged.
        """
        ignore_mask = gt == self._ignore_label
        # Ignored pixels must never count as foreground, even when
        # ignore_label itself is > 0.5 (e.g. 255).
        gt_mask = (gt > 0.5) & ~ignore_mask
        # Rank-agnostic emptiness test (works for 2-D, 3-D, 4-D targets).
        if not gt_mask.any():
            return

        if self._from_logits:
            pred = torch.sigmoid(pred)

        current = self._iou_thresh
        best_thresh = current
        max_iou = _compute_iou(pred > current, gt_mask, ignore_mask).mean()
        # Probe one step below, then one above; only a strictly better IoU
        # replaces the incumbent, so ties keep the current threshold.
        for t in (current - self._thresh_step, current + self._thresh_step):
            temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
            if temp_iou > max_iou:
                max_iou = temp_iou
                best_thresh = t

        self._iou_thresh = (
            self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
        )
        self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
        self._epoch_iou_sum += max_iou
        self._epoch_batch_count += 1

    def get_epoch_value(self):
        """Return the mean best-IoU over the scored batches of this epoch (0.0 if none)."""
        if self._epoch_batch_count > 0:
            return self._epoch_iou_sum / self._epoch_batch_count
        else:
            return 0.0

    def reset_epoch_stats(self):
        """Clear the per-epoch sum and batch count (EMA and threshold persist)."""
        self._epoch_iou_sum = 0.0
        self._epoch_batch_count = 0

    def log_states(self, sw, tag_prefix, global_step):
        """Log the IoU EMA and the current threshold as scalars.

        ``sw`` is presumably a project summary-writer wrapper whose
        ``add_scalar`` takes ``tag``/``value``/``global_step`` keywords --
        confirm against the caller.
        """
        sw.add_scalar(
            tag=tag_prefix + "_ema_iou", value=self._ema_iou, global_step=global_step
        )
        sw.add_scalar(
            tag=tag_prefix + "_iou_thresh",
            value=self._iou_thresh,
            global_step=global_step,
        )

    def iou_thresh(self):
        """Return the current adaptive binarization threshold."""
        return self._iou_thresh
def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
    """Compute per-sample IoU between two binary masks.

    The reduction runs over ``misc.get_dims_with_exclusion(gt_mask.dim(), 0)``,
    presumably every dimension except dim 0, so each entry of the result
    corresponds to one sample along dim 0.

    Args:
        pred_mask: Boolean predicted mask.
        gt_mask: Boolean ground-truth mask, element-wise comparable with
            ``pred_mask``.
        ignore_mask: Optional boolean mask. Where it is true, both the
            prediction and the ground truth are forced to background, so
            those pixels contribute to neither the intersection nor the union.
        keep_ignore: If false, return IoU only for samples whose union is
            non-empty (the result may be shorter than the batch, or empty).
            If true, return one value per sample, with ``-1`` for samples
            whose union is empty.

    Returns:
        A numpy array of IoU values.
    """
    if ignore_mask is not None:
        pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
        # Also drop ignored pixels from the ground truth; otherwise an ignored
        # but GT-positive pixel would still enter the union as a miss.
        gt_mask = torch.where(ignore_mask, torch.zeros_like(gt_mask), gt_mask)
    reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
    # Means (not sums) share the same denominator, so their ratio is the IoU.
    union = (
        torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims)
        .detach()
        .cpu()
        .numpy()
    )
    intersection = (
        torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims)
        .detach()
        .cpu()
        .numpy()
    )
    nonzero = union > 0
    iou = intersection[nonzero] / union[nonzero]
    if not keep_ignore:
        return iou
    else:
        # -1 marks samples whose IoU is undefined (empty union).
        result = np.full_like(intersection, -1)
        result[nonzero] = iou
        return result