henry000 commited on
Commit
c9338ee
Β·
1 Parent(s): 6e46676

🎨 [Add] Dual Loss for yolov9

Browse files
yolo/config/config.py CHANGED
@@ -72,7 +72,7 @@ class MatcherConfig:
72
  @dataclass
73
  class LossConfig:
74
  objective: List[List]
75
- aux: bool
76
  matcher: MatcherConfig
77
 
78
 
 
72
  @dataclass
73
  class LossConfig:
74
  objective: List[List]
75
+ aux: Union[bool, float]
76
  matcher: MatcherConfig
77
 
78
 
yolo/config/hyper/default.yaml CHANGED
@@ -13,11 +13,11 @@ train:
13
  weight_decay: 0.0001
14
  loss:
15
  objective:
16
- - [BCELoss, 0.1]
17
- - [BoxLoss, 0.1]
18
- - [DFLoss, 0.1]
19
  aux:
20
- True
21
  matcher:
22
  iou: CIoU
23
  topk: 10
 
13
  weight_decay: 0.0001
14
  loss:
15
  objective:
16
+ BCELoss: 0.5
17
+ BoxLoss: 7.5
18
+ DFLoss: 1.5
19
  aux:
20
+ 0.25
21
  matcher:
22
  iou: CIoU
23
  topk: 10
yolo/utils/loss.py CHANGED
@@ -15,6 +15,7 @@ from yolo.tools.bbox_helper import (
15
  make_anchor,
16
  transform_bbox,
17
  )
 
18
 
19
 
20
  class BCELoss(nn.Module):
@@ -135,14 +136,10 @@ class YOLOLoss:
135
  anchors_box = anchors_box / self.scaler[None, :, None]
136
  return anchors_cls, anchors_box
137
 
138
- @torch.autocast("cuda" if torch.cuda.is_available() else "cpu")
139
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
140
  # Batch_Size x (Anchor + Class) x H x W
141
  # TODO: check datatype, why targets has a little bit error with origin version
142
- predicts, predicts_anc = self.parse_predicts(predicts[0])
143
- # TODO: Refactor this operator
144
- # targets = self.parse_targets(targets, batch_size=predicts.size(0))
145
- targets[:, :, 1:] = targets[:, :, 1:] * self.scale_up
146
 
147
  align_targets, valid_masks = self.matcher(targets, predicts)
148
  # calculate loss between with instance and predict
@@ -160,11 +157,35 @@ class YOLOLoss:
160
  ## -- DFL -- ##
161
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
162
 
163
- loss_sum = loss_iou * 0.5 + loss_dfl * 1.5 + loss_cls * 0.5
164
- return loss_sum, (loss_iou, loss_dfl, loss_cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  def get_loss_function(cfg: Config) -> YOLOLoss:
168
- loss_function = YOLOLoss(cfg)
169
  logger.info("βœ… Success load loss function")
170
  return loss_function
 
15
  make_anchor,
16
  transform_bbox,
17
  )
18
+ from yolo.tools.module_helper import make_chunk
19
 
20
 
21
  class BCELoss(nn.Module):
 
136
  anchors_box = anchors_box / self.scaler[None, :, None]
137
  return anchors_cls, anchors_box
138
 
 
139
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
140
  # Batch_Size x (Anchor + Class) x H x W
141
  # TODO: check datatype, why targets has a little bit error with origin version
142
+ predicts, predicts_anc = self.parse_predicts(predicts)
 
 
 
143
 
144
  align_targets, valid_masks = self.matcher(targets, predicts)
145
  # calculate loss between with instance and predict
 
157
  ## -- DFL -- ##
158
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
159
 
160
+ return loss_iou, loss_dfl, loss_cls
161
+
162
+
163
+ class DualLoss:
164
+ def __init__(self, cfg: Config) -> None:
165
+ self.loss = YOLOLoss(cfg)
166
+ self.aux_rate = cfg.hyper.train.loss.aux
167
+
168
+ self.iou_rate = cfg.hyper.train.loss.objective["BoxLoss"]
169
+ self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
170
+ self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
171
+
172
+ def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tuple[Tensor]]:
173
+ targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
174
+
175
+ # TODO: Need Refactor this region, make it flexible!
176
+ predicts = make_chunk(predicts[0], 2)
177
+ aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
178
+ main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
179
+
180
+ loss_iou = self.iou_rate * (aux_iou * self.aux_rate + main_iou)
181
+ loss_dfl = self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl)
182
+ loss_cls = self.cls_rate * (aux_cls * self.aux_rate + main_cls)
183
+
184
+ loss = (loss_iou + loss_dfl + loss_cls) / 3
185
+ return loss, (loss_iou, loss_dfl, loss_cls)
186
 
187
 
188
  def get_loss_function(cfg: Config) -> YOLOLoss:
189
+ loss_function = DualLoss(cfg)
190
  logger.info("βœ… Success load loss function")
191
  return loss_function