Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod | |
import torch | |
from ..builder import MASK_ASSIGNERS, build_match_cost | |
try: | |
from scipy.optimize import linear_sum_assignment | |
except ImportError: | |
linear_sum_assignment = None | |
class AssignResult(metaclass=ABCMeta):
    """Plain container holding the outcome of one assignment.

    Attributes are stored exactly as given; no validation or copying is done.

    Args:
        num_gts: Number of ground-truth instances for the sample.
        gt_inds: Per-prediction ground-truth index. As filled in by
            ``MaskHungarianAssigner.assign`` in this file: 0 marks
            background and a positive value is the 1-based gt index.
        labels: Per-prediction label. ``MaskHungarianAssigner.assign``
            leaves unmatched predictions at -1.
    """

    def __init__(self, num_gts, gt_inds, labels):
        self.num_gts, self.gt_inds, self.labels = num_gts, gt_inds, labels

    def info(self):
        """Return the stored fields as a ``dict`` keyed by attribute name."""
        field_names = ("num_gts", "gt_inds", "labels")
        return {field: getattr(self, field) for field in field_names}
class BaseAssigner(metaclass=ABCMeta):
    """Abstract base class for assigners that match predictions to ground truth.

    Subclasses must override :meth:`assign`; the base class itself cannot be
    instantiated.
    """

    @abstractmethod
    def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
        """Assign each prediction to a ground truth or to background.

        Args:
            masks: Predictions to be assigned.
            gt_masks: Ground-truth masks to assign against.
            gt_masks_ignore: Optional ground truths to ignore. Default None.
            gt_labels: Optional labels of ``gt_masks``. Default None.

        Note:
            Concrete assigners may use a different argument list; in this
            file ``MaskHungarianAssigner.assign`` takes classification
            predictions and image meta as well.
        """
class MaskHungarianAssigner(BaseAssigner):
    """Compute a one-to-one matching between mask predictions and ground truth.

    The matching minimises a cost that is the sum of up to three components,
    each built from its config with ``build_match_cost``: a classification
    cost, a mask cost and a dice cost. A component whose ``weight`` is 0 is
    skipped. The targets don't include the no_object, so generally there are
    more predictions than targets. After the one-to-one matching, the
    un-matched predictions are treated as background. Thus each query
    prediction is assigned ``0`` or a positive integer indicating the ground
    truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
        dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
        mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
    """

    # NOTE(review): ``MASK_ASSIGNERS`` is imported at the top of this file but
    # is not used in the visible code -- confirm whether this class is meant
    # to be registered with it (and where).

    def __init__(
        self,
        cls_cost=dict(type="ClassificationCost", weight=1.0),
        dice_cost=dict(type="DiceCost", weight=1.0),
        mask_cost=dict(type="MaskFocalCost", weight=1.0),
    ):
        self.cls_cost = build_match_cost(cls_cost)
        self.dice_cost = build_match_cost(dice_cost)
        self.mask_cost = build_match_cost(mask_cost)

    def _compute_cost(self, cls_pred, mask_pred, gt_labels, gt_masks):
        """Sum the enabled cost terms; returns the int 0 if none is enabled."""
        if self.cls_cost.weight != 0 and cls_pred is not None:
            cls_cost = self.cls_cost(cls_pred, gt_labels)
        else:
            cls_cost = 0
        if self.mask_cost.weight != 0:
            # mask_pred shape = [nq, h, w]
            # gt_mask shape = [ng, h, w]
            # mask_cost shape = [nq, ng]
            mask_cost = self.mask_cost(mask_pred, gt_masks)
        else:
            mask_cost = 0
        if self.dice_cost.weight != 0:
            dice_cost = self.dice_cost(mask_pred, gt_masks)
        else:
            dice_cost = 0
        return cls_cost + mask_cost + dice_cost

    def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
        """Compute a one-to-one matching based on the summed costs.

        Each query prediction is assigned to a ground truth or to background.
        In the returned ``gt_inds``, 0 means negative sample and a positive
        number is the index (1-based) of the assigned gt; -1 only remains in
        the degenerate case noted under Returns. In the returned ``labels``,
        un-matched queries keep -1.

        The assignment is done in the following steps, the order matters.

        1. initialise every prediction to -1
        2. compute the summed costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            cls_pred (Tensor | None): Predicted classification logits, shape
                [num_query, num_class]. May be None, in which case the
                classification cost is skipped and ``mask_pred`` is used to
                determine the number of queries and the output device.
            mask_pred (Tensor): Predicted mask, shape [num_query, h, w].
            gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
            gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
            img_meta (dict): Meta information for current image. Not read
                by this method.
            gt_masks_ignore (Tensor, optional): Ground truth masks that are
                labelled as `ignored`. Must be None. Default None.
            eps (int | float, optional): Kept for interface compatibility;
                not read by this method. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result. If there are no ground
            truths, every query is background (0). If there are ground
            truths but no queries, the (empty) index tensors are returned
            as initialised.

        Raises:
            ValueError: If every cost term is disabled, so there is no cost
                matrix to match on.
            ImportError: If scipy is not installed.
        """
        assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."

        # ``cls_pred`` is allowed to be None (the classification cost is
        # skipped in that case), so the query count, device and output
        # allocation fall back to ``mask_pred``.
        ref = cls_pred if cls_pred is not None else mask_pred
        num_gts, num_queries = gt_labels.shape[0], ref.shape[0]

        # 1. assign -1 by default
        assigned_gt_inds = ref.new_full((num_queries,), -1, dtype=torch.long)
        assigned_labels = ref.new_full((num_queries,), -1, dtype=torch.long)
        if num_gts == 0 or num_queries == 0:
            # No ground truth or no predictions, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)

        # 2. compute the summed costs
        cost = self._compute_cost(cls_pred, mask_pred, gt_labels, gt_masks)
        if not torch.is_tensor(cost):
            # Every term was skipped, so ``cost`` is the plain int 0 and
            # there is nothing to run the matching on.
            raise ValueError(
                "No matching cost is enabled: cls_cost, mask_cost and "
                "dice_cost are all disabled (zero weight, or cls_pred is None)."
            )

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" ' "to install scipy first.")
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(ref.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(ref.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)