# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Object detection reward from "Tuning computer vision models with task rewards" (https://arxiv.org/abs/2302.08242).
The `reward_fn` computes the reward for a batch of predictions and ground truth
annotations. When using it to optimize a model that outputs a prediction as a
sequence of tokens like [y0, x0, Y0, X0, class0, confidence0, y1, x1, Y1, ...]
the training loop may look like:
```
# Settings used in the paper.
config.max_level = 1000 # Coordinates are discretized into 1000 buckets.
config.max_conf = 2 # Two tokens are reserved to represent confidence.
config.num_cls = 80 # Number of classes in COCO.
config.nms_w = 0.3 # Weight for duplicate instances.
config.cls_smooth = 0.05 # Adjust the class weights based on their frequency.
config.reward_thr = (0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
config.correct_thr = 0.5 # Learn to predict the IoU of boxes matched at threshold=0.5.
config.conf_w = 0.3 # Weight for the confidence loss.
# 1) Sample N outputs for each input and compute their rewards; use one sample
# to optimize and the others to compute a reward baseline.
sample_seqs = sample_fn(params, images, num_samples)
sample_rewards, aux = reward_fn(sample_seqs, labels, config)
labels = sample_seqs[:, 0, ...]
rewards = sample_rewards[:, 0]
match_iou = aux["match_iou"][:, 0]
baselines = (jnp.sum(sample_rewards, axis=-1) - rewards) / (num_samples - 1)
# 2) Optimize the model: use REINFORCE to adjust the likelihood of the sampled
# sequence based on its reward, and use supervision to teach the model to
# predict the expected IoU of each box in its own samples.
def loss_fn(params):
  logits = model.apply(params, images, labels, train=True, rngs=rngs)
  logits_softmax = jax.nn.log_softmax(logits)
  # Use REINFORCE to optimize the expected reward for the whole sequence.
  seq_rewards = (rewards - baselines)
  # Note: consider improving this code to skip this loss for confidence tokens.
  # The paper did not do it due to a bug (and it also does not seem to matter).
  target = jax.nn.one_hot(labels, logits.shape[-1]) * seq_rewards[:, None, None]
  loss_reward = -jnp.sum(target * logits_softmax, axis=-1)
  # Use a supervised loss to tune the confidence tokens to predict the IoU:
  #  - (1.0, 0.0, 0.0, ...) -> for padded boxes.
  #  - (0.0, 1-iou, iou, ...) -> for sampled boxes.
  conf0 = (labels[:, 5::6] == 0)
  conf1 = (labels[:, 5::6] > 0) * (1.0 - match_iou)
  conf2 = (labels[:, 5::6] > 0) * match_iou
  target_conf = jnp.stack([conf0, conf1, conf2], axis=-1)
  logits_conf = logits_softmax[:, 5::6, :3]
  loss_conf = -jnp.sum(target_conf * logits_conf, axis=-1)
  loss = jnp.mean(loss_reward) + config.conf_w * jnp.mean(loss_conf)
  return loss
```
"""
import functools
import einops
import jax
import jax.numpy as jnp
# Frequency of COCO object detection classes as observed in the training set.
# pylint: disable=bad-whitespace,bad-continuation
CLS_COUNTS = [
    262465, 7113, 43867, 8725, 5135, 6069, 4571, 9973, 10759,
    12884, 1865, 1983, 1285, 9838, 10806, 4768, 5508, 6587,
    9509, 8147, 5513, 1294, 5303, 5131, 8720, 11431, 12354,
    6496, 6192, 2682, 6646, 2685, 6347, 9076, 3276, 3747,
    5543, 6126, 4812, 24342, 7913, 20650, 5479, 7770, 6165,
    14358, 9458, 5851, 4373, 6399, 7308, 7852, 2918, 5821,
    7179, 6353, 38491, 5779, 8652, 4192, 15714, 4157, 5805,
    4970, 2262, 5703, 2855, 6434, 1673, 3334, 225, 5610,
    2637, 24715, 6334, 6613, 1481, 4793, 198, 1954
]
# pylint: enable=bad-whitespace,bad-continuation


def seq2box(seq, max_level, max_conf, num_cls):
  """Extract boxes encoded as sequences."""
  # Reshape to instances of boxes.
  dim_per_box = 6
  seq_len = seq.shape[-1]
  seq = seq[..., :(seq_len - seq_len % dim_per_box)]
  seq = einops.rearrange(seq, "... (n d) -> ... n d", d=dim_per_box)
  # Unpack box fields.
  boxes, labels, confs = seq[..., 0:4], seq[..., 4], seq[..., 5]
  boxes = boxes - max_conf - 1
  labels = labels - max_conf - 1 - max_level - 1
  boxes = jnp.clip(boxes, 0, max_level) / max_level
  labels = jnp.clip(labels, 0, num_cls - 1)
  confs = jnp.clip(confs, 0, max_conf)
  return boxes, labels, confs
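
# Decoding sketch (hypothetical token values, assuming the paper settings above:
# max_level=1000, max_conf=2, num_cls=80):
#   seq = jnp.array([3, 103, 503, 1003, 1004, 2])  # one box: [y0, x0, Y0, X0, class0, conf0]
#   boxes, labels, confs = seq2box(seq, max_level=1000, max_conf=2, num_cls=80)
#   # boxes  == [[0.0, 0.1, 0.5, 1.0]]  (coordinates rescaled to [0, 1])
#   # labels == [0], confs == [2]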


def iou_fn(box1, box2):
  """Compute IoU of two boxes in (ymin, xmin, ymax, xmax) format."""
  ymin1, xmin1, ymax1, xmax1 = box1
  ymin2, xmin2, ymax2, xmax2 = box2
  # Areas of the two boxes.
  a1 = jnp.abs((ymax1 - ymin1) * (xmax1 - xmin1))
  a2 = jnp.abs((ymax2 - ymin2) * (xmax2 - xmin2))
  # Vertical and horizontal extents of the intersection.
  yl = jnp.maximum(ymin1, ymin2)
  yr = jnp.minimum(ymax1, ymax2)
  yi = jnp.maximum(0, yr - yl)
  xl = jnp.maximum(xmin1, xmin2)
  xr = jnp.minimum(xmax1, xmax2)
  xi = jnp.maximum(0, xr - xl)
  inter = xi * yi
  return inter / (a1 + a2 - inter + 1e-9)


# Pairwise IoU between two sets of boxes: (N, 4) x (M, 4) -> (N, M).
iou_fn_batched = jax.vmap(
    jax.vmap(iou_fn, in_axes=(None, 0)), in_axes=(0, None)
)
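
# Example usage (a small sketch with hypothetical boxes; both functions expect
# (ymin, xmin, ymax, xmax) coordinates):
#   iou_fn(jnp.array([0., 0., 1., 1.]), jnp.array([0., 0., 0.5, 1.]))  # ~= 0.5
#   iou_fn_batched(boxes_pred, boxes_gt)  # (N, 4) x (M, 4) -> (N, M) IoU matrix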


def _reward_fn_thr(seq_pred, seq_gt,
                   thr, nms_w, max_level, max_conf, num_cls, cls_smooth):
  """Compute the detection reward for a given IoU threshold."""
  # Weight matches of each label inversely proportional to the percentage of
  # GT instances with that label in the whole train dataset. Additionally
  # smooth out the observed distribution.
  cls_counts = jnp.array(CLS_COUNTS)
  weights = 1.0 / (cls_counts + cls_smooth * jnp.sum(cls_counts))
  weights = num_cls * weights / jnp.sum(weights)

  boxes_pred, labels_pred, confs_pred = seq2box(
      seq_pred, max_level, max_conf, num_cls)
  boxes_gt, labels_gt, confs_gt = seq2box(
      seq_gt, max_level, max_conf, num_cls)

  # Compute IoU matrix: predictions x GT.
  iou = iou_fn_batched(boxes_pred, boxes_gt)
  # Keep only pairs above the IoU threshold.
  iou = jnp.where(iou > thr, iou, 0.0)
  # Mask out padded/EOS boxes (confidence 0) on both sides.
  confs_mask = (confs_pred[:, None] > 0) * (confs_gt[None, :] > 0)
  iou = confs_mask * iou
  # Only allow matches with the same class label.
  label_mask = labels_pred[:, None] == labels_gt[None, :]
  iou = label_mask * iou
  # Each prediction is matched to a single GT box (its best overlap).
  single_match_mask = jax.nn.one_hot(jnp.argmax(iou, axis=1), iou.shape[1])
  iou = iou * single_match_mask
  # Per-prediction indicator: 0 = padded, 1 = unmatched, 2 = matched.
  correct = jnp.any(iou > 0.0, axis=1).astype("int32") + 1
  correct = jnp.where(confs_pred > 0, correct, 0)
  # For each GT box find the best matching prediction.
  matches_idx = jnp.argmax(iou, axis=0)
  matches_iou = jnp.take_along_axis(iou, matches_idx[None], axis=0)[0]
  matches_idx = jnp.where(matches_iou > 0.0, matches_idx, -1)
  match_reward = jnp.sum((matches_idx >= 0) * weights[labels_gt][None, :])
  # Compute duplicate penalty (aka NMS).
  matches_mask = jax.nn.one_hot(matches_idx, iou.shape[0], axis=0)
  nms_penalty = jnp.sum(
      (iou > 0.0) * (1 - matches_mask) * weights[labels_pred][:, None])
  match_iou = jnp.sum(iou, axis=1)
  return {
      "reward": (match_reward - nms_w * nms_penalty),
      "num_matches": jnp.sum(matches_idx >= 0),
      "nms_penalty": nms_penalty,
      "correct": correct,
      "match_iou": match_iou,
  }
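
# Shape sketch (hypothetical single example, with the paper settings): given
# seq_pred and seq_gt of shape [seq_len], "reward", "num_matches" and
# "nms_penalty" are scalars, while "correct" and "match_iou" have one entry per
# predicted box (seq_len // 6):
#   out = _reward_fn_thr(seq_pred, seq_gt, thr=0.5, nms_w=0.3, max_level=1000,
#                        max_conf=2, num_cls=80, cls_smooth=0.05)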


def reward_fn(seqs_pred, seqs_gt, config):
  """Total reward, averaged over all IoU thresholds in `config.reward_thr`."""
  result = {}
  thrs = config.reward_thr
  # Note: correct_thr must be one of reward_thr, otherwise "correct" and
  # "match_iou" would never be set below.
  correct_thr = config.correct_thr
  r_keys = ["reward", "num_matches", "nms_penalty"]
  for thr in thrs:
    fn = functools.partial(
        _reward_fn_thr,
        thr=thr,
        nms_w=config.nms_w,
        max_level=config.max_level,
        max_conf=config.max_conf,
        num_cls=config.num_cls,
        cls_smooth=config.cls_smooth,
    )
    # Map over the batch and over the samples of each example.
    rewards = jax.vmap(jax.vmap(fn, in_axes=(0, None)))(seqs_pred, seqs_gt)
    result = {**result, **{f"{k}-{thr:0.2f}": rewards[k] for k in r_keys}}
    if thr == correct_thr:
      correct = rewards["correct"]
      match_iou = rewards["match_iou"]
  # Average the per-threshold metrics over all thresholds.
  result = {
      **result,
      **{k: jnp.mean(
          jnp.array([result[f"{k}-{thr:0.2f}"] for thr in thrs]), axis=0)
         for k in r_keys}
  }
  return result["reward"], {
      "result": result,
      "correct": correct,
      "match_iou": match_iou,
  }
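
# Example usage (a minimal sketch; the config fields mirror the paper settings
# in the module docstring, and ml_collections.ConfigDict is just one way to
# provide them):
#   import ml_collections
#   config = ml_collections.ConfigDict(dict(
#       max_level=1000, max_conf=2, num_cls=80, nms_w=0.3, cls_smooth=0.05,
#       reward_thr=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95),
#       correct_thr=0.5))
#   # seqs_pred: [batch, num_samples, seq_len] sampled token sequences.
#   # seqs_gt:   [batch, seq_len] ground-truth token sequences.
#   rewards, aux = reward_fn(seqs_pred, seqs_gt, config)
#   # rewards: [batch, num_samples] reward averaged over config.reward_thr.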