henry000 commited on
Commit
253e9b1
Β·
1 Parent(s): ac20d16

🚚 [Move] parse_predict to Anchor2Box class

Browse files
Files changed (2) hide show
  1. yolo/tools/bbox_helper.py +42 -1
  2. yolo/utils/loss.py +4 -44
yolo/tools/bbox_helper.py CHANGED
@@ -3,9 +3,10 @@ from typing import List, Tuple
3
 
4
  import torch
5
  import torch.nn.functional as F
 
6
  from torch import Tensor
7
 
8
- from yolo.config.config import MatcherConfig
9
 
10
 
11
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
@@ -122,6 +123,46 @@ def make_anchor(image_size: List[int], strides: List[int], device):
122
  return all_anchors, all_scalers
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  class BoxMatcher:
126
  def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
127
  self.class_num = class_num
 
3
 
4
  import torch
5
  import torch.nn.functional as F
6
+ from einops import rearrange
7
  from torch import Tensor
8
 
9
+ from yolo.config.config import Config, MatcherConfig
10
 
11
 
12
  def calculate_iou(bbox1, bbox2, metrics="iou") -> Tensor:
 
123
  return all_anchors, all_scalers
124
 
125
 
126
+ class Anchor2Box:
127
+ def __init__(self, cfg: Config, device: torch.device) -> None:
128
+ self.reg_max = cfg.model.anchor.reg_max
129
+ self.class_num = cfg.hyper.data.class_num
130
+ self.image_size = list(cfg.hyper.data.image_size)
131
+ self.strides = cfg.model.anchor.strides
132
+
133
+ self.scale_up = torch.tensor(self.image_size * 2, device=device)
134
+ self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
135
+ self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
136
+
137
+ def __call__(self, predicts: List[Tensor], with_logits=False) -> Tensor:
138
+ """
139
+ args:
140
+ [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
141
+ return:
142
+ [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
143
+ """
144
+ preds = []
145
+ for pred in predicts:
146
+ preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
147
+ preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
148
+
149
+ preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
150
+ preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
151
+ if with_logits:
152
+ preds_cls = preds_cls.sigmoid()
153
+
154
+ pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
155
+
156
+ lt, rb = pred_LTRB.chunk(2, dim=-1)
157
+ pred_minXY = self.anchors - lt
158
+ pred_maxXY = self.anchors + rb
159
+ preds_box = torch.cat([pred_minXY, pred_maxXY], dim=-1)
160
+
161
+ predicts = torch.cat([preds_cls, preds_box], dim=-1)
162
+
163
+ return predicts, preds_anc
164
+
165
+
166
  class BoxMatcher:
167
  def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
168
  self.class_num = class_num
yolo/utils/loss.py CHANGED
@@ -8,12 +8,7 @@ from torch import Tensor, nn
8
  from torch.nn import BCEWithLogitsLoss
9
 
10
  from yolo.config.config import Config
11
- from yolo.tools.bbox_helper import (
12
- BoxMatcher,
13
- calculate_iou,
14
- make_anchor,
15
- transform_bbox,
16
- )
17
  from yolo.tools.module_helper import make_chunk
18
 
19
 
@@ -90,42 +85,7 @@ class YOLOLoss:
90
  self.iou = BoxLoss()
91
 
92
  self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
93
-
94
- def parse_predicts(self, predicts: List[Tensor]) -> Tensor:
95
- """
96
- args:
97
- [B x AnchorClass x h1 x w1, B x AnchorClass x h2 x w2, B x AnchorClass x h3 x w3] // AnchorClass = 4 * 16 + 80
98
- return:
99
- [B x HW x ClassBbox] // HW = h1*w1 + h2*w2 + h3*w3, ClassBox = 80 + 4 (xyXY)
100
- """
101
- preds = []
102
- for pred in predicts:
103
- preds.append(rearrange(pred, "B AC h w -> B (h w) AC")) # B x AC x h x w-> B x hw x AC
104
- preds = torch.concat(preds, dim=1) # -> B x (H W) x AC
105
-
106
- preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
107
- preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
108
-
109
- pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
110
-
111
- lt, rb = pred_LTRB.chunk(2, dim=-1)
112
- pred_minXY = self.anchors - lt
113
- pred_maxXY = self.anchors + rb
114
- predicts = torch.cat([preds_cls, pred_minXY, pred_maxXY], dim=-1)
115
-
116
- return predicts, preds_anc
117
-
118
- def parse_targets(self, targets: Tensor, batch_size: int = 16) -> List[Tensor]:
119
- """
120
- return List:
121
- """
122
- targets[:, 2:] = transform_bbox(targets[:, 2:], "xycwh -> xyxy") * self.scale_up
123
- bbox_num = targets[:, 0].int().bincount()
124
- batch_targets = torch.zeros(batch_size, bbox_num.max(), 5, device=targets.device)
125
- for instance_idx, bbox_num in enumerate(bbox_num):
126
- instance_targets = targets[targets[:, 0] == instance_idx]
127
- batch_targets[instance_idx, :bbox_num] = instance_targets[:, 1:].detach()
128
- return batch_targets
129
 
130
  def separate_anchor(self, anchors):
131
  """
@@ -138,10 +98,10 @@ class YOLOLoss:
138
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
139
  # Batch_Size x (Anchor + Class) x H x W
140
  # TODO: check datatype, why targets has a little bit error with origin version
141
- predicts, predicts_anc = self.parse_predicts(predicts)
142
 
 
143
  align_targets, valid_masks = self.matcher(targets, predicts)
144
- # calculate loss between with instance and predict
145
 
146
  targets_cls, targets_bbox = self.separate_anchor(align_targets)
147
  predicts_cls, predicts_bbox = self.separate_anchor(predicts)
 
8
  from torch.nn import BCEWithLogitsLoss
9
 
10
  from yolo.config.config import Config
11
+ from yolo.tools.bbox_helper import Anchor2Box, BoxMatcher, calculate_iou, make_anchor
 
 
 
 
 
12
  from yolo.tools.module_helper import make_chunk
13
 
14
 
 
85
  self.iou = BoxLoss()
86
 
87
  self.matcher = BoxMatcher(cfg.hyper.train.loss.matcher, self.class_num, self.anchors)
88
+ self.box_converter = Anchor2Box(cfg, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def separate_anchor(self, anchors):
91
  """
 
98
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
99
  # Batch_Size x (Anchor + Class) x H x W
100
  # TODO: check datatype, why targets has a little bit error with origin version
101
+ predicts, predicts_anc = self.box_converter(predicts)
102
 
103
+ # For each predicted targets, assign a best suitable ground truth box.
104
  align_targets, valid_masks = self.matcher(targets, predicts)
 
105
 
106
  targets_cls, targets_bbox = self.separate_anchor(align_targets)
107
  predicts_cls, predicts_bbox = self.separate_anchor(predicts)