Spaces:
Sleeping
Sleeping
File size: 9,023 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Argmax matcher implementation.
This class takes a similarity matrix and matches columns to rows based on the
maximum value per column. One can specify a matched_threshold below which a
column does not match its best row, and an unmatched_threshold that splits such
columns into negatives (generally resulting in a negative training example) and
ignored columns (generally resulting in neither a positive nor a negative
training example).
This matcher is used in Fast(er)-RCNN.
Note: matchers are used in TargetAssigners. There is a create_target_assigner
factory function for popular implementations.
"""
import tensorflow as tf, tf_keras
from official.vision.utils.object_detection import matcher
from official.vision.utils.object_detection import shape_utils
class ArgMaxMatcher(matcher.Matcher):
  """Matcher based on highest value.

  This class computes matches from a similarity matrix. Each column is matched
  to a single row.

  To support object detection target assignment this class enables setting both
  matched_threshold (upper threshold) and unmatched_threshold (lower threshold)
  defining three categories of similarity which define whether examples are
  positive, negative, or ignored:
  (1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
  (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
      Depending on negatives_lower_than_unmatched, this is either
      Unmatched/Negative OR Ignore.
  (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
      negatives_lower_than_unmatched, either Unmatched/Negative or Ignore.

  In the returned match vector, ignored columns are set to -2 and
  unmatched (negative) columns are set to -1.
  """

  def __init__(self,
               matched_threshold,
               unmatched_threshold=None,
               negatives_lower_than_unmatched=True,
               force_match_for_each_row=False):
    """Construct ArgMaxMatcher.

    Args:
      matched_threshold: Threshold for positive matches. Positive if
        sim >= matched_threshold, where sim is the maximum value of the
        similarity matrix for a given column. Set to None for no threshold,
        in which case every column is matched to its argmax row.
      unmatched_threshold: Threshold for negative matches. Negative if
        sim < unmatched_threshold. Defaults to matched_threshold
        when set to None.
      negatives_lower_than_unmatched: Boolean which defaults to True. If True
        then negative matches are the ones below the unmatched_threshold,
        whereas ignored matches are in between the matched and unmatched
        threshold. If False, then negative matches are in between the matched
        and unmatched threshold, and everything lower than unmatched is ignored.
      force_match_for_each_row: If True, ensures that each row is matched to
        at least one column (which is not guaranteed otherwise if the
        matched_threshold is high). Defaults to False. See
        argmax_matcher_test.testMatcherForceMatch() for an example.

    Raises:
      ValueError: if unmatched_threshold is set but matched_threshold is not
        set, if unmatched_threshold > matched_threshold, or if
        negatives_lower_than_unmatched is False and the two (effective)
        thresholds are equal (this includes both being None).
    """
    if (matched_threshold is None) and (unmatched_threshold is not None):
      # Note the trailing space: adjacent literals are concatenated verbatim.
      raise ValueError('Need to also define matched_threshold when '
                       'unmatched_threshold is defined')
    self._matched_threshold = matched_threshold
    if unmatched_threshold is None:
      self._unmatched_threshold = matched_threshold
    else:
      if unmatched_threshold > matched_threshold:
        raise ValueError('unmatched_threshold needs to be smaller or equal '
                         'to matched_threshold')
      self._unmatched_threshold = unmatched_threshold
    if not negatives_lower_than_unmatched:
      # With equal thresholds the "between" band would be empty, so no column
      # could ever become a negative.
      if self._unmatched_threshold == self._matched_threshold:
        # Format eagerly: ValueError does not apply logging-style %-args.
        raise ValueError('When negatives are in between matched and '
                         'unmatched thresholds, these cannot be of equal '
                         'value. matched: %s, unmatched: %s' %
                         (self._matched_threshold, self._unmatched_threshold))
    self._force_match_for_each_row = force_match_for_each_row
    self._negatives_lower_than_unmatched = negatives_lower_than_unmatched

  def _match(self, similarity_matrix):
    """Tries to match each column of the similarity matrix to a row.

    Args:
      similarity_matrix: tensor of shape [N, M] representing any similarity
        metric.

    Returns:
      An int32 tensor of shape [M]. Entry j is the index of the row that
      column j matches to, -1 if column j is unmatched (negative), or -2 if
      column j is ignored.
    """

    def _match_when_rows_are_empty():
      """Performs matching when the rows of similarity matrix are empty.

      When the rows are empty, all detections are false positives. So we return
      a tensor of -1's to indicate that the columns do not match to any rows.

      Returns:
        matches: int32 tensor indicating the row each column matches to.
      """
      similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
          similarity_matrix)
      return -1 * tf.ones([similarity_matrix_shape[1]], dtype=tf.int32)

    def _match_when_rows_are_non_empty():
      """Performs matching when the rows of similarity matrix are non-empty.

      Returns:
        matches: int32 tensor indicating the row each column matches to.
      """
      # Best row for each column.
      matches = tf.argmax(input=similarity_matrix, axis=0, output_type=tf.int32)

      # Deal with matched and unmatched threshold.
      if self._matched_threshold is not None:
        # Boolean per-column indicators of which similarity band each column's
        # best score falls into.
        matched_vals = tf.reduce_max(input_tensor=similarity_matrix, axis=0)
        below_unmatched_threshold = tf.greater(self._unmatched_threshold,
                                               matched_vals)
        between_thresholds = tf.logical_and(
            tf.greater_equal(matched_vals, self._unmatched_threshold),
            tf.greater(self._matched_threshold, matched_vals))

        if self._negatives_lower_than_unmatched:
          # Below unmatched -> negative (-1); in between -> ignored (-2).
          matches = self._set_values_using_indicator(matches,
                                                     below_unmatched_threshold,
                                                     -1)
          matches = self._set_values_using_indicator(matches,
                                                     between_thresholds,
                                                     -2)
        else:
          # Below unmatched -> ignored (-2); in between -> negative (-1).
          matches = self._set_values_using_indicator(matches,
                                                     below_unmatched_threshold,
                                                     -2)
          matches = self._set_values_using_indicator(matches,
                                                     between_thresholds,
                                                     -1)

      if self._force_match_for_each_row:
        similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
            similarity_matrix)
        # Best column for each row.
        force_match_column_ids = tf.argmax(
            input=similarity_matrix, axis=1, output_type=tf.int32)
        # [N, M] indicator of the (row, best column) pairs to force.
        force_match_column_indicators = tf.one_hot(
            force_match_column_ids, depth=similarity_matrix_shape[1])
        # For each column, a row that picked it (if several rows picked the
        # same column, one of them is chosen by argmax).
        force_match_row_ids = tf.argmax(
            input=force_match_column_indicators, axis=0, output_type=tf.int32)
        # Columns that were picked by at least one row.
        force_match_column_mask = tf.cast(
            tf.reduce_max(input_tensor=force_match_column_indicators, axis=0),
            tf.bool)
        # Forced matches override the threshold-based -1/-2 assignments.
        final_matches = tf.where(force_match_column_mask, force_match_row_ids,
                                 matches)
        return final_matches
      return matches

    if similarity_matrix.shape.is_fully_defined():
      # Static shape is known: pick the branch at trace time.
      if similarity_matrix.shape.as_list()[0] == 0:
        return _match_when_rows_are_empty()
      return _match_when_rows_are_non_empty()
    # Row count only known at run time: branch inside the graph.
    return tf.cond(
        pred=tf.greater(tf.shape(input=similarity_matrix)[0], 0),
        true_fn=_match_when_rows_are_non_empty,
        false_fn=_match_when_rows_are_empty)

  def _set_values_using_indicator(self, x, indicator, val):
    """Set the indicated fields of x to val.

    Args:
      x: tensor.
      indicator: boolean with same shape as x.
      val: scalar with value to set.

    Returns:
      modified tensor.
    """
    # Arithmetic blend: x where indicator is 0, val where indicator is 1.
    indicator = tf.cast(indicator, x.dtype)
    return tf.add(tf.multiply(x, 1 - indicator), val * indicator)
|