π [Merge] branch 'TEST'
Browse files- yolo/tools/bbox_helper.py +64 -3
- yolo/utils/drawer.py +9 -5
- yolo/utils/loss.py +4 -44
yolo/tools/bbox_helper.py
CHANGED
@@ -3,9 +3,11 @@ from typing import List, Tuple
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
|
|
6 |
from torch import Tensor
|
|
|
7 |
|
8 |
-
from yolo.config.config import MatcherConfig
|
9 |
|
10 |
|
11 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
@@ -122,6 +124,46 @@ def make_anchor(image_size: List[int], strides: List[int], device):
|
|
122 |
return all_anchors, all_scalers
|
123 |
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
class BoxMatcher:
|
126 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
127 |
self.class_num = class_num
|
@@ -224,11 +266,9 @@ class BoxMatcher:
|
|
224 |
# get cls matrix (cls prob with each gt class and each predict class)
|
225 |
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
|
226 |
|
227 |
-
# TODO: alpha and beta should be set at hydra
|
228 |
target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
|
229 |
|
230 |
# choose topk
|
231 |
-
# TODO: topk should be set at hydra
|
232 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
233 |
|
234 |
# delete one anchor pred assign to multiple gts
|
@@ -249,3 +289,24 @@ class BoxMatcher:
|
|
249 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
250 |
|
251 |
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
from torch import Tensor
|
8 |
+
from torchvision.ops import batched_nms
|
9 |
|
10 |
+
from yolo.config.config import Config, MatcherConfig
|
11 |
|
12 |
|
13 |
def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
|
|
|
124 |
return all_anchors, all_scalers
|
125 |
|
126 |
|
127 |
+
class Anchor2Box:
    """Decode raw multi-scale head outputs into class scores and xyxy boxes.

    The anchors and their per-anchor scaler are built once, at construction
    time, by ``make_anchor`` from the configured image size and strides, and
    are reused for every call.
    """

    def __init__(self, cfg: Config, device: torch.device) -> None:
        """Read the anchor/data settings from ``cfg`` and precompute constants on ``device``."""
        self.reg_max = cfg.model.anchor.reg_max
        self.class_num = cfg.hyper.data.class_num
        self.image_size = list(cfg.hyper.data.image_size)
        self.strides = cfg.model.anchor.strides

        # List repetition: the image size repeated twice, i.e. one entry per box coordinate.
        self.scale_up = torch.tensor(self.image_size * 2, device=device)
        self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
        # Bin values 0 .. reg_max - 1, used to turn a distribution over bins into its expected value.
        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)

    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tuple[Tensor, Tensor]:
        """
        args:
            predicts: [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3]
                // AnchorClass = 4 * reg_max + class_num (e.g. 4 * 16 + 80)
            with_logits: when True, the class scores are passed through a sigmoid before being returned
        return:
            a tuple ``(predicts, preds_anc)``:
            predicts:  [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = class_num + 4 (xyXY)
            preds_anc: [B x HW x 4 x reg_max] // raw (pre-softmax) distance distribution per box side
        """
        # Flatten every scale to B x (h w) x AC and stack them along the anchor axis.
        flattened = [rearrange(pred, "B AC h w -> B (h w) AC") for pred in predicts]
        preds = torch.cat(flattened, dim=1)  # -> B x (H W) x AC

        # The channel axis is laid out as [4 * reg_max box bins | class_num class scores].
        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
        preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
        if with_logits:
            preds_cls = preds_cls.sigmoid()

        # Expected bin value per side, multiplied by the per-anchor scaler returned by make_anchor
        # (presumably the stride of the anchor's feature level -- confirm against make_anchor).
        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)

        # First half of the 4 sides is subtracted from the anchor (min corner),
        # second half is added to it (max corner).
        lt, rb = pred_LTRB.chunk(2, dim=-1)
        pred_minXY = self.anchors - lt
        pred_maxXY = self.anchors + rb
        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)

        predicts = torch.cat([preds_cls, preds_box], dim=-1)

        return predicts, preds_anc
|
165 |
+
|
166 |
+
|
167 |
class BoxMatcher:
|
168 |
def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
|
169 |
self.class_num = class_num
|
|
|
266 |
# get cls matrix (cls prob with each gt class and each predict class)
|
267 |
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
|
268 |
|
|
|
269 |
target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
|
270 |
|
271 |
# choose topk
|
|
|
272 |
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
273 |
|
274 |
# delete one anchor pred assign to multiple gts
|
|
|
289 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
290 |
|
291 |
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
|
292 |
+
|
293 |
+
|
294 |
+
def bbox_nms(predicts: Tensor, min_conf: float = 0, min_iou: float = 0.5) -> List[Tensor]:
    """Filter predictions by confidence and apply per-image non-maximum suppression.

    Args:
        predicts: tensor whose last axis is ``[class scores..., 4 box coordinates]``.
            The number of classes is taken from the tensor itself (last dim minus 4).
            The indexing below assumes a ``B x HW x (C + 4)`` layout.
        min_conf: a prediction is kept only if its best class score is strictly
            greater than this value.
        min_iou: IoU threshold forwarded to ``batched_nms``.

    Returns:
        A list with exactly one tensor per image in the batch. Each tensor has
        shape ``K x 5`` with rows ``[class index, box coordinates (4)]``; ``K`` is 0
        for images without any surviving prediction.
    """
    batch_size = predicts.size(0)
    # Derive the class count from the input instead of assuming 80 classes.
    class_num = predicts.size(-1) - 4
    cls_dist, bbox = predicts.split([class_num, 4], dim=-1)

    # filter class by confidence
    cls_val, cls_idx = cls_dist.max(dim=-1)
    valid_mask = cls_val > min_conf
    if not valid_mask.any():
        # Nothing passed the threshold: skip NMS but still report one (empty) entry per image.
        return [predicts.new_empty((0, 5)) for _ in range(batch_size)]

    batch_idx, grid_idx = torch.where(valid_mask)
    valid_conf = cls_val[batch_idx, grid_idx]
    valid_cls = cls_idx[batch_idx, grid_idx]
    valid_box = bbox[batch_idx, grid_idx]

    # batched_nms(boxes, scores, idxs, iou_threshold): boxes must be ranked by their
    # confidence, not by their class index. Grouping by batch index keeps boxes of
    # different images from suppressing each other.
    nms_idx = batched_nms(valid_box, valid_conf, batch_idx, min_iou)
    kept_batch = batch_idx[nms_idx]

    predicts_nms = []
    # Iterate over the real batch size so images without detections are not dropped.
    for idx in range(batch_size):
        instance_idx = nms_idx[kept_batch == idx]
        instance_cls = valid_cls[instance_idx].unsqueeze(-1).to(valid_box.dtype)
        predicts_nms.append(torch.cat([instance_cls, valid_box[instance_idx]], dim=-1))
    return predicts_nms
|
yolo/utils/drawer.py
CHANGED
@@ -7,7 +7,9 @@ from PIL import Image, ImageDraw, ImageFont
|
|
7 |
from torchvision.transforms.functional import to_pil_image
|
8 |
|
9 |
|
10 |
-
def draw_bboxes(
|
|
|
|
|
11 |
"""
|
12 |
Draw bounding boxes on an image.
|
13 |
|
@@ -30,16 +32,18 @@ def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[i
|
|
30 |
|
31 |
for bbox in bboxes:
|
32 |
class_id, x_min, y_min, x_max, y_max = bbox
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
37 |
shape = [(x_min, y_min), (x_max, y_max)]
|
38 |
draw.rectangle(shape, outline="red", width=3)
|
39 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
40 |
|
41 |
img.save("visualize.jpg") # Save the image with annotations
|
42 |
logger.info("Saved visualize image at visualize.png")
|
|
|
43 |
|
44 |
|
45 |
def draw_model(*, model_cfg=None, model=None, v7_base=False):
|
|
|
7 |
from torchvision.transforms.functional import to_pil_image
|
8 |
|
9 |
|
10 |
+
def draw_bboxes(
|
11 |
+
img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]], *, scaled_bbox: bool = True
|
12 |
+
):
|
13 |
"""
|
14 |
Draw bounding boxes on an image.
|
15 |
|
|
|
32 |
|
33 |
for bbox in bboxes:
|
34 |
class_id, x_min, y_min, x_max, y_max = bbox
|
35 |
+
if scaled_bbox:
|
36 |
+
x_min = x_min * width
|
37 |
+
x_max = x_max * width
|
38 |
+
y_min = y_min * height
|
39 |
+
y_max = y_max * height
|
40 |
shape = [(x_min, y_min), (x_max, y_max)]
|
41 |
draw.rectangle(shape, outline="red", width=3)
|
42 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
43 |
|
44 |
img.save("visualize.jpg") # Save the image with annotations
|
45 |
logger.info("Saved visualize image at visualize.png")
|
46 |
+
return img
|
47 |
|
48 |
|
49 |
def draw_model(*, model_cfg=None, model=None, v7_base=False):
|
yolo/utils/loss.py
CHANGED
@@ -8,12 +8,7 @@ from torch import Tensor, nn
|
|
8 |
from torch.nn import BCEWithLogitsLoss
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
-
from yolo.tools.bbox_helper import
|
12 |
-
BoxMatcher,
|
13 |
-
calculate_iou,
|
14 |
-
make_anchor,
|
15 |
-
transform_bbox,
|
16 |
-
)
|
17 |
from yolo.tools.module_helper import make_chunk
|
18 |
|
19 |
|
@@ -90,42 +85,7 @@ class YOLOLoss:
|
|
90 |
self.iou = BoxLoss()
|
91 |
|
92 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
93 |
-
|
94 |
-
def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
|
95 |
-
"""
|
96 |
-
args:
|
97 |
-
[B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
|
98 |
-
return:
|
99 |
-
[B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
|
100 |
-
"""
|
101 |
-
preds = []
|
102 |
-
for pred in predicts:
|
103 |
-
preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
|
104 |
-
preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
|
105 |
-
|
106 |
-
preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
|
107 |
-
preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
|
108 |
-
|
109 |
-
pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
|
110 |
-
|
111 |
-
lt, rb = pred_LTRB.chunk(2, dim=-1)
|
112 |
-
pred_minXY = self.anchors - lt
|
113 |
-
pred_maxXY = self.anchors + rb
|
114 |
-
predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
|
115 |
-
|
116 |
-
return predicts, preds_anc
|
117 |
-
|
118 |
-
def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
|
119 |
-
"""
|
120 |
-
return List:
|
121 |
-
"""
|
122 |
-
targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
|
123 |
-
bbox_num = targets[:, 0].int().bincount()
|
124 |
-
batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
|
125 |
-
for instance_idx, bbox_num in enumerate(bbox_num):
|
126 |
-
instance_targets = targets[targets[:, 0] == instance_idx]
|
127 |
-
batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
|
128 |
-
return batch_targets
|
129 |
|
130 |
def separate_anchor(self, anchors):
|
131 |
"""
|
@@ -138,10 +98,10 @@ class YOLOLoss:
|
|
138 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
139 |
# Batch_Size x (Anchor + Class) x H x W
|
140 |
# TODO: check datatype, why targets has a little bit error with origin version
|
141 |
-
predicts, predicts_anc = self.
|
142 |
|
|
|
143 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
144 |
-
# calculate loss between with instance and predict
|
145 |
|
146 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
147 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|
|
|
8 |
from torch.nn import BCEWithLogitsLoss
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
+
from yolo.tools.bbox_helper import Anchor2Box, BoxMatcher, calculate_iou, make_anchor
|
|
|
|
|
|
|
|
|
|
|
12 |
from yolo.tools.module_helper import make_chunk
|
13 |
|
14 |
|
|
|
85 |
self.iou = BoxLoss()
|
86 |
|
87 |
self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
|
88 |
+
self.box_converter = Anchor2Box(cfg, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def separate_anchor(self, anchors):
|
91 |
"""
|
|
|
98 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
99 |
# Batch_Size x (Anchor + Class) x H x W
|
100 |
# TODO: check datatype, why targets has a little bit error with origin version
|
101 |
+
predicts, predicts_anc = self.box_converter(predicts)
|
102 |
|
103 |
+
# For each prediction, assign the best-matching ground-truth box.
|
104 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
|
|
105 |
|
106 |
targets_cls, targets_bbox = self.separate_anchor(align_targets)
|
107 |
predicts_cls, predicts_bbox = self.separate_anchor(predicts)
|