Spaces:
Sleeping
Sleeping
# Copyright 2020 Google Research. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Base target assigner module.

The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
groundtruth detections (bounding boxes), to assign classification and regression
targets to each anchor as well as weights to each anchor (specifying, e.g.,
which anchors should not contribute to training loss).

It assigns classification/regression targets by performing the following steps:
1) Computing pairwise similarity between anchors and groundtruth boxes using a
   provided RegionSimilarityCalculator
2) Computing a matching based on the similarity matrix using a provided Matcher
3) Assigning regression targets based on the matching and a provided BoxCoder
4) Assigning classification targets based on the matching and groundtruth labels

In this implementation, `TargetAssigner.assign` returns only the classification
targets, the regression targets and the match; per-anchor weights are not returned.

Note that TargetAssigners only operate on detections from a single
image at a time, so any logic for applying a TargetAssigner to multiple
images must be handled externally.
"""
import torch | |
from typing import Optional | |
from . import box_list | |
from .region_similarity_calculator import IouSimilarity | |
from .argmax_matcher import ArgMaxMatcher | |
from .matcher import Match | |
from .box_list import BoxList | |
from .box_coder import FasterRcnnBoxCoder | |
# Name of the optional BoxList field in which groundtruth keypoints are stored.
KEYPOINTS_FIELD_NAME = 'keypoints'

# TorchScript compilation of TargetAssigner is currently disabled:
#@torch.jit.script
class TargetAssigner(object):
    """Target assigner to compute classification and regression targets."""

    def __init__(self, similarity_calc: IouSimilarity, matcher: ArgMaxMatcher, box_coder: FasterRcnnBoxCoder,
                 negative_class_weight: float = 1.0, unmatched_cls_target: Optional[float] = None,
                 keypoints_field_name: str = KEYPOINTS_FIELD_NAME):
        """Construct Object Detection Target Assigner.

        Args:
            similarity_calc: region similarity calculator used to compare groundtruth
                boxes against anchors (its `compare` method is called in `assign`).
            matcher: Matcher used to match groundtruth to anchors.
            box_coder: BoxCoder used to encode matching groundtruth boxes with respect to anchors.
            negative_class_weight: classification weight associated with unmatched
                (negative) anchors (default: 1.0). The weight is expected to be in
                [0., 1.]; it is not validated here.
            unmatched_cls_target: classification target given to anchors that are
                unmatched or ignored. If None, the scalar 0. is used.
                NOTE(review): the upstream TF API allowed a tensor of shape
                [d_1, ..., d_k] here; confirm `Match.gather_based_on_match`
                supports that in this port before relying on it.
            keypoints_field_name: name of the optional BoxList field holding
                groundtruth keypoints.
        """
        self._similarity_calc = similarity_calc
        self._matcher = matcher
        self._box_coder = box_coder
        self._negative_class_weight = negative_class_weight
        self._unmatched_cls_target = 0. if unmatched_cls_target is None else unmatched_cls_target
        self._keypoints_field_name = keypoints_field_name

    def assign(self, anchors: BoxList, groundtruth_boxes: BoxList, groundtruth_labels=None, groundtruth_weights=None):
        """Assign classification and regression targets to each anchor.

        For a given set of anchors and groundtruth detections, match anchors
        to groundtruth_boxes and assign classification and regression targets to
        each anchor based on the resulting match. Anchors that are not matched to
        anything are given a classification target of self._unmatched_cls_target,
        which can be specified via the constructor.

        Args:
            anchors: a BoxList representing N anchors.
            groundtruth_boxes: a BoxList representing M groundtruth boxes.
            groundtruth_labels: a tensor of shape [M, d_1, ... d_k] with labels for
                each of the groundtruth boxes. The subshape [d_1, ... d_k] can be
                empty (scalar labels). When None, a binary problem is assumed and
                every groundtruth box gets the scalar positive label 1 (shape [M]).
            groundtruth_weights: currently unused; per-anchor weights are not
                computed by this implementation. Kept for API compatibility.

        Returns:
            A tuple (cls_targets, reg_targets, match):
            cls_targets: a tensor with shape [num_anchors, d_1, d_2 ... d_k], where
                the subshape [d_1, ..., d_k] matches that of groundtruth_labels.
            reg_targets: a tensor with shape [num_anchors, box_code_dimension].
            match: a matcher.Match object encoding the match between anchors and
                groundtruth boxes.

        Raises:
            ValueError: if anchors or groundtruth_boxes are not of type box_list.BoxList.
        """
        if not isinstance(anchors, box_list.BoxList):
            raise ValueError('anchors must be an BoxList')
        if not isinstance(groundtruth_boxes, box_list.BoxList):
            raise ValueError('groundtruth_boxes must be an BoxList')

        if groundtruth_labels is None:
            # Documented binary default: every groundtruth box is positive. Scalar
            # labels of shape [M] pair with the default scalar unmatched target (0.).
            groundtruth_labels = torch.ones(groundtruth_boxes.num_boxes(), device=anchors.device())

        match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes, anchors)
        match = self._matcher.match(match_quality_matrix)
        reg_targets = self._create_regression_targets(anchors, groundtruth_boxes, match)
        cls_targets = self._create_classification_targets(groundtruth_labels, match)
        return cls_targets, reg_targets, match

    def _create_regression_targets(self, anchors: BoxList, groundtruth_boxes: BoxList, match: Match):
        """Returns a regression target for each anchor.

        Args:
            anchors: a BoxList representing N anchors.
            groundtruth_boxes: a BoxList representing M groundtruth_boxes.
            match: a matcher.Match object.

        Returns:
            reg_targets: a tensor with shape [N, box_code_dimension]. Anchors that
                are not matched receive the default (all-zero) regression target.
        """
        device = anchors.device()
        # Unmatched / ignored anchors gather an all-zero placeholder box.
        zero_box = torch.zeros((1, 4), device=device)
        matched_gt_boxes = match.gather_based_on_match(
            groundtruth_boxes.boxes(), unmatched_value=zero_box, ignored_value=zero_box)
        matched_gt_boxlist = box_list.BoxList(matched_gt_boxes)
        if groundtruth_boxes.has_field(self._keypoints_field_name):
            # Carry keypoints along so the box coder can encode them too.
            groundtruth_keypoints = groundtruth_boxes.get_field(self._keypoints_field_name)
            zero_kp = torch.zeros((1,) + groundtruth_keypoints.shape[1:], device=device)
            matched_keypoints = match.gather_based_on_match(
                groundtruth_keypoints, unmatched_value=zero_kp, ignored_value=zero_kp)
            matched_gt_boxlist.add_field(self._keypoints_field_name, matched_keypoints)
        matched_reg_targets = self._box_coder.encode(matched_gt_boxlist, anchors)
        # [1, code_size] default target; torch.where broadcasts it over all anchors,
        # so there is no need to materialise an [N, code_size] copy.
        default_reg_target = self._default_regression_target(device)
        matched_anchors_mask = match.matched_column_indicator()
        reg_targets = torch.where(matched_anchors_mask.unsqueeze(1), matched_reg_targets, default_reg_target)
        return reg_targets

    def _default_regression_target(self, device: torch.device):
        """Returns the default target for anchors to regress to.

        Default regression targets are set to zero. They are only used for
        anchors that are not matched to any groundtruth box.

        Returns:
            default_target: a float32 tensor with shape [1, box_code_dimension].
        """
        return torch.zeros(1, self._box_coder.code_size(), device=device)

    def _create_classification_targets(self, groundtruth_labels, match: Match):
        """Create classification targets for each anchor.

        Each matched anchor takes the label of the groundtruth box it is matched
        to. Anchors that are unmatched or ignored are given the target
        self._unmatched_cls_target.

        Args:
            groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
                with labels for each of the groundtruth boxes. The subshape
                [d_1, ... d_k] can be empty (corresponding to scalar labels).
            match: a matcher.Match object that provides a matching between anchors
                and groundtruth boxes.

        Returns:
            a tensor with shape [num_anchors, d_1, d_2 ... d_k], where the subshape
            [d_1, ..., d_k] matches that of groundtruth_labels.
        """
        return match.gather_based_on_match(
            groundtruth_labels,
            unmatched_value=self._unmatched_cls_target, ignored_value=self._unmatched_cls_target)

    def _create_regression_weights(self, match: Match, groundtruth_weights):
        """Set regression weight for each anchor.

        Only matched anchors contribute to the regression loss: each matched
        anchor takes the weight of its matched groundtruth box, while unmatched
        and ignored anchors get weight 0. Not called by `assign`.

        Args:
            match: a matcher.Match object that provides a matching between anchors and groundtruth boxes.
            groundtruth_weights: a float tensor of shape [M] indicating the weight to
                assign to all anchors matched to a particular groundtruth box.

        Returns:
            a tensor with shape [num_anchors] representing regression weights.
        """
        return match.gather_based_on_match(groundtruth_weights, ignored_value=0., unmatched_value=0.)

    def _create_classification_weights(self, match: Match, groundtruth_weights):
        """Create classification weights for each anchor.

        Matched anchors take the weight of their matched groundtruth box,
        unmatched (negative) anchors get self._negative_class_weight, and ignored
        anchors get weight 0. Not called by `assign`.

        Args:
            match: a matcher.Match object that provides a matching between anchors and groundtruth boxes.
            groundtruth_weights: a float tensor of shape [M] indicating the weight to
                assign to all anchors matched to a particular groundtruth box.

        Returns:
            a tensor with shape [num_anchors] representing classification weights.
        """
        return match.gather_based_on_match(
            groundtruth_weights, ignored_value=0., unmatched_value=self._negative_class_weight)

    def box_coder(self):
        """Get BoxCoder of this TargetAssigner.

        Returns:
            BoxCoder object.
        """
        return self._box_coder