henry000 committed
Commit a26a166 · 2 Parent(s): 09d9de8 a87899e

🔀 [Merge] branch 'TEST'

yolo/tools/bbox_helper.py CHANGED
@@ -3,9 +3,11 @@ from typing import List, Tuple
 
 import torch
 import torch.nn.functional as F
+from einops import rearrange
 from torch import Tensor
+from torchvision.ops import batched_nms
 
-from yolo.config.config import MatcherConfig
+from yolo.config.config import Config, MatcherConfig
 
 
 def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
@@ -122,6 +124,46 @@ def make_anchor(image_size: List[int], strides: List[int], device):
     return all_anchors, all_scalers
 
 
+class Anchor2Box:
+    def __init__(self, cfg: Config, device: torch.device) -> None:
+        self.reg_max = cfg.model.anchor.reg_max
+        self.class_num = cfg.hyper.data.class_num
+        self.image_size = list(cfg.hyper.data.image_size)
+        self.strides = cfg.model.anchor.strides
+
+        self.scale_up = torch.tensor(self.image_size * 2, device=device)
+        self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
+        self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
+
+    def __call__(self, predicts: List[Tensor], with_logits=False) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
+        Returns:
+            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBbox = 80 + 4 (xyXY)
+        """
+        preds = []
+        for pred in predicts:
+            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
+        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
+
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)
+        if with_logits:
+            preds_cls = preds_cls.sigmoid()
+
+        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
+
+        lt, rb = pred_LTRB.chunk(2, dim=-1)
+        pred_minXY = self.anchors - lt
+        pred_maxXY = self.anchors + rb
+        preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)
+
+        predicts = torch.cat([preds_cls, preds_box], dim=-1)
+
+        return predicts, preds_anc
+
+
 class BoxMatcher:
     def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
         self.class_num = class_num
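
Anchor2Box is the decoding step lifted out of YOLOLoss (see yolo/utils/loss.py below). Each side of a box is predicted as a distribution over reg_max bins; preds_anc.softmax(dim=-1) @ self.reverse_reg takes the expected value of that distribution, which the per-anchor scaler then converts to pixels. A minimal sketch of the expected-value step on toy logits, with an assumed stride of 8 (illustrative values, not repository code):

    import torch

    reg_max = 16
    reverse_reg = torch.arange(reg_max, dtype=torch.float32)  # bin values 0..15

    # toy logits for one anchor and one box side, favoring bins 3 and 4
    side_logits = torch.full((reg_max,), -4.0)
    side_logits[3] = side_logits[4] = 4.0

    offset = side_logits.softmax(dim=-1) @ reverse_reg  # expected bin index, ~3.5
    print(offset.item() * 8)  # pixels from anchor center to the box edge, ~28
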
 
@@ -224,11 +266,9 @@ class BoxMatcher:
         # get cls matrix (cls prob with each gt class and each predict class)
         cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
 
-        # TODO: alpha and beta should be set at hydra
         target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
 
         # choose topk
-        # TODO: topk should be set at hydra
         topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
 
         # delete one anchor pred assigned to multiple gts
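
The two deleted TODOs are presumably settled by config plumbing: the exponents behind self.factor and the topk value now come via MatcherConfig rather than waiting on hydra. The score itself is a task-aligned assignment metric; a toy evaluation for one gt/anchor pair, with assumed exponents iou=6.0 and cls=0.5 (the real values live in the hydra config, not in this diff):

    import torch

    iou_mat = torch.tensor(0.8)  # IoU between the gt box and the predicted box
    cls_mat = torch.tensor(0.6)  # predicted probability of the gt class
    grid_mask = 1.0              # anchor center falls inside the gt box

    # high IoU dominates the score; class confidence softens it
    target_matrix = grid_mask * iou_mat**6.0 * cls_mat**0.5
    print(target_matrix)  # ~0.203
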
 
@@ -249,3 +289,25 @@ class BoxMatcher:
         align_cls = align_cls * normalize_term * valid_mask[:, :, None]
 
         return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
+
+
+def bbox_nms(predicts: Tensor, min_conf: float = 0, min_iou: float = 0.5):
+    cls_dist, bbox = predicts.split([80, 4], dim=-1)
+
+    # filter class by confidence
+    cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
+    valid_mask = cls_val > min_conf
+    valid_cls = cls_idx[valid_mask]
+    valid_con = cls_val[valid_mask]  # confidences, used as the NMS scores
+    valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
+
+    batch_idx, *_ = torch.where(valid_mask)
+    nms_idx = batched_nms(valid_box, valid_con, batch_idx, min_iou)
+    predicts_nms = []
+    for idx in range(batch_idx.max() + 1):
+        instance_idx = nms_idx[idx == batch_idx[nms_idx]]
+
+        predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
+
+        predicts_nms.append(predict_nms)
+    return predicts_nms
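
A short usage sketch for the new bbox_nms; the input is random stand-in data shaped like Anchor2Box output, with the 80-class layout the function hardcodes in its split:

    import torch

    # stand-in predictions: 2 images x 8400 anchors x (80 class scores + xyXY box)
    predicts = torch.rand(2, 8400, 84)
    predicts[..., 80:] *= 640  # pretend the boxes live on a 640x640 image

    for image_idx, boxes in enumerate(bbox_nms(predicts, min_conf=0.25, min_iou=0.5)):
        print(image_idx, boxes.shape)  # N x 5: class_id, x_min, y_min, x_max, y_max

The returned list holds one N x 5 tensor per image, which matches the layout draw_bboxes below consumes when called with scaled_bbox=False.
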
yolo/utils/drawer.py CHANGED
@@ -7,7 +7,9 @@ from PIL import Image, ImageDraw, ImageFont
 from torchvision.transforms.functional import to_pil_image
 
 
-def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
+def draw_bboxes(
+    img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]], *, scaled_bbox: bool = True
+):
     """
     Draw bounding boxes on an image.
 
@@ -30,16 +32,18 @@ def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
 
     for bbox in bboxes:
         class_id, x_min, y_min, x_max, y_max = bbox
-        x_min = x_min * width
-        x_max = x_max * width
-        y_min = y_min * height
-        y_max = y_max * height
+        if scaled_bbox:
+            x_min = x_min * width
+            x_max = x_max * width
+            y_min = y_min * height
+            y_max = y_max * height
         shape = [(x_min, y_min), (x_max, y_max)]
         draw.rectangle(shape, outline="red", width=3)
         draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
 
     img.save("visualize.jpg")  # Save the image with annotations
     logger.info("Saved visualize image at visualize.jpg")
+    return img
 
 
 def draw_model(*, model_cfg=None, model=None, v7_base=False):
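
With the new scaled_bbox keyword, draw_bboxes accepts either normalized boxes (the default path, multiplied by image width/height) or pixel-space boxes such as bbox_nms output. A small usage sketch with a blank test image and made-up box values:

    from PIL import Image

    from yolo.utils.drawer import draw_bboxes

    img = Image.new("RGB", (640, 640))  # stand-in for a real photo

    # normalized xyxy boxes: the default path scales them by the image size
    draw_bboxes(img, [[0, 0.1, 0.2, 0.5, 0.6]])

    # pixel-space boxes (e.g. from bbox_nms): skip the rescale, keep the returned image
    annotated = draw_bboxes(img, [[0, 64.0, 128.0, 320.0, 384.0]], scaled_bbox=False)
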
yolo/utils/loss.py CHANGED
@@ -8,12 +8,7 @@ from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss
 
 from yolo.config.config import Config
-from yolo.tools.bbox_helper import (
-    BoxMatcher,
-    calculate_iou,
-    make_anchor,
-    transform_bbox,
-)
+from yolo.tools.bbox_helper import Anchor2Box, BoxMatcher, calculate_iou, make_anchor
 from yolo.tools.module_helper import make_chunk
 
 
@@ -90,42 +85,7 @@ class YOLOLoss:
         self.iou = BoxLoss()
 
         self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
-
-    def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
-        """
-        args:
-            [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
-        return:
-            [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
-        """
-        preds = []
-        for pred in predicts:
-            preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
-        preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
-        preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
-
-        pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
-
-        lt, rb = pred_LTRB.chunk(2, dim=-1)
-        pred_minXY = self.anchors - lt
-        pred_maxXY = self.anchors + rb
-        predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
-
-        return predicts, preds_anc
-
-    def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
-        """
-        return List:
-        """
-        targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
-        bbox_num = targets[:, 0].int().bincount()
-        batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
-        for instance_idx, bbox_num in enumerate(bbox_num):
-            instance_targets = targets[targets[:, 0] == instance_idx]
-            batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
-        return batch_targets
+        self.box_converter = Anchor2Box(cfg, device)
 
     def separate_anchor(self, anchors):
         """
@@ -138,10 +98,10 @@
     def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
         # Batch_Size x (Anchor + Class) x H x W
         # TODO: check datatype; targets deviate slightly from the original version
-        predicts, predicts_anc = self.parse_predicts(predicts)
+        predicts, predicts_anc = self.box_converter(predicts)
 
+        # For each predicted target, assign the best-suited ground truth box.
         align_targets, valid_masks = self.matcher(targets, predicts)
-        # calculate loss between with instance and predict
 
         targets_cls, targets_bbox = self.separate_anchor(align_targets)
         predicts_cls, predicts_bbox = self.separate_anchor(predicts)
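
Net effect of the loss.py changes: parse_predicts moves into the shared Anchor2Box and parse_targets is dropped from this class, so training and inference decode boxes through one code path. A shape-contract sketch for what self.box_converter hands back to __call__; the feature-map sizes assume a 640x640 input with strides 8/16/32, and the commented lines need a real hydra cfg:

    import torch

    # fake multi-scale head outputs: B x (4*reg_max + class_num) x h x w = 2 x 144 x h x w
    feats = [torch.rand(2, 144, 640 // s, 640 // s) for s in (8, 16, 32)]

    # converter = Anchor2Box(cfg, torch.device("cpu"))  # cfg assumed: reg_max=16, class_num=80
    # predicts, predicts_anc = converter(feats)
    # predicts:     2 x 8400 x 84      (80 class scores + xyXY box per anchor)
    # predicts_anc: 2 x 8400 x 4 x 16  (raw per-side bin logits, presumably for the box loss)
    print(sum(f.shape[-2] * f.shape[-1] for f in feats))  # 8400 anchors = 80*80 + 40*40 + 20*20
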