Spaces:
Runtime error
Runtime error
File size: 14,478 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import HorizontalBoxes, bbox_rescale, get_box_tensor
from .base_bbox_coder import BaseBBoxCoder
@TASK_UTILS.register_module()
class BucketingBBoxCoder(BaseBBoxCoder):
    """Box coder for Side-Aware Boundary Localization (SABL).

    Wraps the bucketing based boundary localization (``bbox2bucket``) and
    the bucketing guided rescoring / decoding (``bucket2bbox``) helpers.
    See https://arxiv.org/abs/1912.04260 for the method itself.

    Args:
        num_buckets (int): Number of buckets.
        scale_factor (int): Scale factor applied to proposals before the
            buckets are generated.
        offset_topk (int): Number of nearest buckets that receive fine
            regression targets. Defaults to 2.
        offset_upperbound (float): Upper bound on the offset magnitude for
            a non-nearest bucket to still receive a fine regression
            target. Defaults to 1.0.
        cls_ignore_neighbor (bool): Whether to ignore the second nearest
            bucket in the bucketing classification weights.
            Defaults to True.
        clip_border (bool, optional): Whether to clip decoded boxes to the
            image border. Defaults to True.
    """

    def __init__(self,
                 num_buckets,
                 scale_factor,
                 offset_topk=2,
                 offset_upperbound=1.0,
                 cls_ignore_neighbor=True,
                 clip_border=True,
                 **kwargs):
        super().__init__(**kwargs)
        # Bucket layout.
        self.num_buckets = num_buckets
        self.scale_factor = scale_factor
        # Target generation options (used by ``encode``).
        self.offset_topk = offset_topk
        self.offset_upperbound = offset_upperbound
        self.cls_ignore_neighbor = cls_ignore_neighbor
        # Decoding option (used by ``decode``).
        self.clip_border = clip_border

    def encode(self, bboxes, gt_bboxes):
        """Build bucketing and fine regression targets for training.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g. object proposals.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target boxes,
                e.g. ground truth boxes. Must have the same number of
                rows as ``bboxes``.

        Returns:
            tuple[Tensor]: The tuple produced by ``bbox2bucket``:
            fine regression targets and weights, then bucketing
            estimation labels and weights.
        """
        src = get_box_tensor(bboxes)
        dst = get_box_tensor(gt_bboxes)
        # Boxes are paired row by row and both must be 4-d.
        assert src.size(0) == dst.size(0)
        assert src.size(-1) == dst.size(-1) == 4
        return bbox2bucket(src, dst, self.num_buckets, self.scale_factor,
                           self.offset_topk, self.offset_upperbound,
                           self.cls_ignore_neighbor)

    def decode(self, bboxes, pred_bboxes, max_shape=None):
        """Apply the predictions in ``pred_bboxes`` to ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes.
            pred_bboxes (tuple[torch.Tensor]): A pair
                ``(cls_preds, offset_preds)`` holding the bucketing
                estimation and the fine regression predictions.
            max_shape (tuple[int], optional): Maximum shape of boxes.
                Defaults to None.

        Returns:
            tuple: ``(bboxes, loc_confidence)``. ``bboxes`` is wrapped in
            :obj:`HorizontalBoxes` when ``self.use_box_type`` is truthy,
            otherwise it is a plain tensor.
        """
        boxes = get_box_tensor(bboxes)
        assert len(pred_bboxes) == 2
        cls_preds, offset_preds = pred_bboxes
        num_boxes = boxes.size(0)
        assert (cls_preds.size(0) == num_boxes
                and offset_preds.size(0) == num_boxes)

        decoded, loc_confidence = bucket2bbox(boxes, cls_preds, offset_preds,
                                              self.num_buckets,
                                              self.scale_factor, max_shape,
                                              self.clip_border)
        if self.use_box_type:
            decoded = HorizontalBoxes(decoded, clone=False)
        return decoded, loc_confidence
def generat_buckets(proposals, num_buckets, scale_factor=1.0):
    """Generate per-side bucket positions for a set of proposals.

    The proposals are first rescaled by ``scale_factor``. Each axis is then
    split into ``num_buckets`` equal steps, and for every side the first
    ``ceil(num_buckets / 2)`` positions ``(j + 0.5) * step`` are measured
    inwards from that side's edge.

    Args:
        proposals (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.

    Returns:
        tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
        t_buckets, d_buckets)

        - bucket_w: Width of buckets on x-axis. Shape (n, ).
        - bucket_h: Height of buckets on y-axis. Shape (n, ).
        - l_buckets: Left buckets. Shape (n, ceil(num_buckets/2)).
        - r_buckets: Right buckets. Shape (n, ceil(num_buckets/2)).
        - t_buckets: Top buckets. Shape (n, ceil(num_buckets/2)).
        - d_buckets: Down buckets. Shape (n, ceil(num_buckets/2)).
    """
    rescaled = bbox_rescale(proposals, scale_factor)

    # number of buckets in each side
    side_num = int(np.ceil(num_buckets / 2.0))

    x1, y1, x2, y2 = (rescaled[..., i] for i in range(4))
    bucket_w = (x2 - x1) / num_buckets
    bucket_h = (y2 - y1) / num_buckets

    # (j + 0.5) for j in [0, side_num); shared by all four sides.
    steps = (0.5 +
             torch.arange(0, side_num).to(rescaled).float())[None, :]
    shift_x = steps * bucket_w[:, None]
    shift_y = steps * bucket_h[:, None]

    # Left/top buckets grow from the min edge, right/down from the max edge.
    l_buckets = x1[:, None] + shift_x
    r_buckets = x2[:, None] - shift_x
    t_buckets = y1[:, None] + shift_y
    d_buckets = y2[:, None] - shift_y
    return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
def bbox2bucket(proposals,
                gt,
                num_buckets,
                scale_factor,
                offset_topk=2,
                offset_upperbound=1.0,
                cls_ignore_neighbor=True):
    """Generate bucketing estimation and fine regression targets.

    For every side (left, right, top, down) the offset of each bucket to
    the matching ground truth coordinate is computed in units of the bucket
    size. The nearest bucket becomes the classification label, and the
    ``offset_topk`` nearest buckets receive fine regression weights.

    Args:
        proposals (Tensor): Shape (n, 4)
        gt (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        offset_topk (int): Topk buckets are used to generate
            bucket fine regression targets. Defaults to 2.
        offset_upperbound (float): Offset allowance to generate
            bucket fine regression targets.
            To avoid too large offset displacements. Defaults to 1.0.
        cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
            Defaults to True.

    Returns:
        tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).

        - offsets: Fine regression targets. \
            Shape (n, 4 * ceil(num_buckets / 2)).
        - offsets_weights: Fine regression weights. \
            Shape (n, 4 * ceil(num_buckets / 2)).
        - bucket_labels: Bucketing estimation labels (one-hot). \
            Shape (n, 4 * ceil(num_buckets / 2)).
        - cls_weights: Bucketing estimation weights. \
            Shape (n, 4 * ceil(num_buckets / 2)).

        All four are laid out side by side as left, right, top, down.
    """
    assert proposals.size() == gt.size()

    # generate buckets
    proposals = proposals.float()
    gt = gt.float()
    (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
     d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)

    # (bucket positions, gt coordinate, bucket size) for left, right,
    # top, down -- this order fixes the layout of every output.
    sides = (
        (l_buckets, gt[..., 0], bucket_w),
        (r_buckets, gt[..., 2], bucket_w),
        (t_buckets, gt[..., 1], bucket_h),
        (d_buckets, gt[..., 3], bucket_h),
    )
    row_inds = torch.arange(0, proposals.size(0)).to(proposals).long()

    side_offsets = []
    side_offset_weights = []
    nearest_labels = []
    for buckets, gt_coord, bucket_size in sides:
        # offsets from buckets to gts, normalised by the bucket size
        side_offset = (buckets - gt_coord[:, None]) / bucket_size[:, None]

        # select top-k nearest buckets
        dists, labels = side_offset.abs().topk(
            offset_topk, dim=1, largest=False, sorted=True)

        # The nearest bucket is always regressed; further ones only when
        # their offset stays below ``offset_upperbound``.
        weights = side_offset.new_zeros(side_offset.size())
        for k in range(offset_topk):
            if k == 0:
                weights[row_inds, labels[:, k]] = 1.0
            else:
                weights[row_inds, labels[:, k]] = (
                    dists[:, k] < offset_upperbound).float()

        side_offsets.append(side_offset)
        side_offset_weights.append(weights)
        nearest_labels.append(labels[:, 0])

    offsets = torch.cat(side_offsets, dim=-1)
    offsets_weights = torch.cat(side_offset_weights, dim=-1)

    # generate bucket labels and weight
    side_num = int(np.ceil(num_buckets / 2.0))
    stacked_labels = torch.stack(nearest_labels, dim=-1)
    num_rows = stacked_labels.size(0)
    bucket_labels = F.one_hot(stacked_labels.view(-1),
                              side_num).view(num_rows, -1).float()

    bucket_cls_weights = torch.cat(
        [(side_offset.abs() < 1).float() for side_offset in side_offsets],
        dim=-1)

    # ignore second nearest buckets for cls if necessary
    if cls_ignore_neighbor:
        # Zero weight only for buckets within one bucket size of the gt
        # that are not the labelled bucket.
        bucket_cls_weights = ((bucket_cls_weights != 1) |
                              (bucket_labels != 0)).float()
    else:
        bucket_cls_weights[:] = 1.0
    return offsets, offsets_weights, bucket_labels, bucket_cls_weights
def bucket2bbox(proposals,
                cls_preds,
                offset_preds,
                num_buckets,
                scale_factor=1.0,
                max_shape=None,
                clip_border=True):
    """Apply bucketing estimation (cls preds) and fine regression (offset
    preds) to generate det bboxes.

    Args:
        proposals (Tensor): Boxes to be transformed. Shape (n, 4)
        cls_preds (Tensor): bucketing estimation. Viewed as
            (n * 4, ceil(num_buckets / 2)), rows cycling through
            left, right, top, down for each box.
        offset_preds (Tensor): fine regression. Same layout as
            ``cls_preds``.
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
        clip_border (bool, optional): Whether clip the objects outside the
            border of the image. Defaults to True.

    Returns:
        tuple[Tensor]: (bboxes, loc_confidence).

        - bboxes: predicted bboxes. Shape (n, 4)
        - loc_confidence: localization confidence of predicted bboxes.
            Shape (n,).
    """
    side_num = int(np.ceil(num_buckets / 2.0))
    # One row per (box, side).
    cls_preds = cls_preds.view(-1, side_num)
    offset_preds = offset_preds.view(-1, side_num)

    scores = F.softmax(cls_preds, dim=1)
    score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)

    rescaled_proposals = bbox_rescale(proposals, scale_factor)

    pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
    ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
    px1 = rescaled_proposals[..., 0]
    py1 = rescaled_proposals[..., 1]
    px2 = rescaled_proposals[..., 2]
    py2 = rescaled_proposals[..., 3]

    bucket_w = pw / num_buckets
    bucket_h = ph / num_buckets

    # Best-scoring bucket of each side (rows are interleaved l, r, t, d).
    score_inds_l = score_label[0::4, 0]
    score_inds_r = score_label[1::4, 0]
    score_inds_t = score_label[2::4, 0]
    score_inds_d = score_label[3::4, 0]
    l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
    r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
    t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
    d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h

    offsets = offset_preds.view(-1, 4, side_num)

    # Build the row index directly as int64 on the proposals' device.
    # Routing it through the dtype of ``proposals`` (as ``.to(proposals)``
    # does) rounds indices above 2048 when proposals are float16, which
    # would gather offsets from the wrong rows.
    inds = torch.arange(proposals.size(0), device=proposals.device)
    l_offsets = offsets[:, 0, :][inds, score_inds_l]
    r_offsets = offsets[:, 1, :][inds, score_inds_r]
    t_offsets = offsets[:, 2, :][inds, score_inds_t]
    d_offsets = offsets[:, 3, :][inds, score_inds_d]

    # Offsets are expressed in bucket sizes; undo the normalisation.
    x1 = l_buckets - l_offsets * bucket_w
    x2 = r_buckets - r_offsets * bucket_w
    y1 = t_buckets - t_offsets * bucket_h
    y2 = d_buckets - d_offsets * bucket_h

    if clip_border and max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
    bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
                       dim=-1)

    # bucketing guided rescoring: start from the top-1 score of each side
    # and add the top-2 score when the two best buckets are adjacent.
    loc_confidence = score_topk[:, 0]
    top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
    loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
    # Average the four side confidences of each box.
    loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)

    return bboxes, loc_confidence
|