🔀 [Merge] branch 'MODELv2' into TEST
Browse files — yolo/utils/bounding_box_utils.py +32 −21
yolo/utils/bounding_box_utils.py
CHANGED
@@ -108,12 +108,13 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
|
|
108 |
return bbox.to(dtype=data_type)
|
109 |
|
110 |
|
111 |
-
def generate_anchors(image_size: List[int],
|
112 |
"""
|
113 |
Find the anchor maps for each w, h.
|
114 |
|
115 |
Args:
|
116 |
-
|
|
|
117 |
|
118 |
Returns:
|
119 |
all_anchors [HW x 2]:
|
@@ -122,15 +123,14 @@ def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
|
|
122 |
W, H = image_size
|
123 |
anchors = []
|
124 |
scaler = []
|
125 |
-
for
|
126 |
-
|
127 |
-
anchor_num = anchor_wh[0] * anchor_wh[1]
|
128 |
scaler.append(torch.full((anchor_num,), stride))
|
129 |
shift = stride // 2
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
anchor = torch.stack([
|
134 |
anchors.append(anchor)
|
135 |
all_anchors = torch.cat(anchors, dim=0)
|
136 |
all_scalers = torch.cat(scaler, dim=0)
|
@@ -172,6 +172,7 @@ class BoxMatcher:
|
|
172 |
Returns:
|
173 |
[batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
|
174 |
"""
|
|
|
175 |
target_cls = target_cls.expand(-1, -1, 8400)
|
176 |
predict_cls = predict_cls.transpose(1, 2)
|
177 |
cls_probabilities = torch.gather(predict_cls, 1, target_cls)
|
@@ -266,24 +267,34 @@ class BoxMatcher:
|
|
266 |
|
267 |
class Vec2Box:
|
268 |
def __init__(self, model: YOLO, image_size, device):
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
for predict_head in dummy_output["Main"]:
|
275 |
-
_, _, *anchor_num = predict_head[2].shape
|
276 |
-
anchors_num.append(anchor_num)
|
277 |
else:
|
278 |
-
logger.info(
|
279 |
-
|
280 |
|
|
|
281 |
if not isinstance(model, YOLO):
|
282 |
device = torch.device("cpu")
|
283 |
|
284 |
-
anchor_grid, scaler = generate_anchors(image_size,
|
285 |
self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
def __call__(self, predicts):
|
289 |
preds_cls, preds_anc, preds_box = [], [], []
|
|
|
108 |
return bbox.to(dtype=data_type)
|
109 |
|
110 |
|
111 |
+
def generate_anchors(image_size: List[int], strides: List[int]):
|
112 |
"""
|
113 |
Find the anchor maps for each w, h.
|
114 |
|
115 |
Args:
|
116 |
+
image_size List: the image size of augmented image size
|
117 |
+
strides List[8, 16, 32, ...]: the stride size for each predicted layer
|
118 |
|
119 |
Returns:
|
120 |
all_anchors [HW x 2]:
|
|
|
123 |
W, H = image_size
|
124 |
anchors = []
|
125 |
scaler = []
|
126 |
+
for stride in strides:
|
127 |
+
anchor_num = W // stride * H // stride
|
|
|
128 |
scaler.append(torch.full((anchor_num,), stride))
|
129 |
shift = stride // 2
|
130 |
+
h = torch.arange(0, H, stride) + shift
|
131 |
+
w = torch.arange(0, W, stride) + shift
|
132 |
+
anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
|
133 |
+
anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
|
134 |
anchors.append(anchor)
|
135 |
all_anchors = torch.cat(anchors, dim=0)
|
136 |
all_scalers = torch.cat(scaler, dim=0)
|
|
|
172 |
Returns:
|
173 |
[batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
|
174 |
"""
|
175 |
+
# TODO: Turn 8400 to HW
|
176 |
target_cls = target_cls.expand(-1, -1, 8400)
|
177 |
predict_cls = predict_cls.transpose(1, 2)
|
178 |
cls_probabilities = torch.gather(predict_cls, 1, target_cls)
|
|
|
267 |
|
268 |
class Vec2Box:
|
269 |
def __init__(self, model: YOLO, image_size, device):
    """Prepare anchor grids and stride scalers for decoding YOLO head outputs.

    Args:
        model: the YOLO model (or an exported wrapper, e.g. onnx) whose
            strides determine the anchor layout.
        image_size: [W, H] of the (augmented) input image.
        device: device the anchor grid and scaler tensors are moved to.
    """
    self.device = device

    # Use the three-argument getattr: without the default, a model that has
    # no `strides` attribute raises AttributeError here, so the else branch
    # (dummy-forward auto-anchor detection) would never run.
    if getattr(model, "strides", None):
        logger.info(f"๐ถ Found stride of model {model.strides}")
        self.strides = model.strides
    else:
        logger.info("๐งธ Found no stride of model, performed a dummy test for auto-anchor size")
        self.strides = self.create_auto_anchor(model, image_size)

    # TODO: this is an exception for onnx; remove it when the onnx device issue is fixed
    if not isinstance(model, YOLO):
        device = torch.device("cpu")

    anchor_grid, scaler = generate_anchors(image_size, self.strides)
    self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
|
286 |
+
def create_auto_anchor(self, model: YOLO, image_size):
    """Infer the per-head strides by pushing a zero-filled dummy image through the model.

    Each stride is the ratio between the input height and the spatial width
    of that head's output feature map.
    """
    probe = torch.zeros(1, 3, *image_size).to(self.device)
    detected = []
    for head in model(probe)["Main"]:
        _, _, *feature_size = head[2].shape
        detected.append(image_size[1] // feature_size[1])
    return detected
294 |
+
|
295 |
+
def update(self, image_size):
    """Recompute the anchor grid and scaler for a new image size, reusing the cached strides."""
    grid, scale = generate_anchors(image_size, self.strides)
    self.anchor_grid = grid.to(self.device)
    self.scaler = scale.to(self.device)
298 |
|
299 |
def __call__(self, predicts):
|
300 |
preds_cls, preds_anc, preds_box = [], [], []
|