π¨ [Add] Dual Loss for yolov9
Browse files- yolo/config/config.py +1 -1
- yolo/config/hyper/default.yaml +4 -4
- yolo/utils/loss.py +29 -8
yolo/config/config.py
CHANGED
@@ -72,7 +72,7 @@ class MatcherConfig:
|
|
72 |
@dataclass
|
73 |
class LossConfig:
|
74 |
objective: List[List]
|
75 |
-
aux: bool
|
76 |
matcher: MatcherConfig
|
77 |
|
78 |
|
|
|
72 |
@dataclass
|
73 |
class LossConfig:
|
74 |
objective: List[List]
|
75 |
+
aux: Union[bool, float]
|
76 |
matcher: MatcherConfig
|
77 |
|
78 |
|
yolo/config/hyper/default.yaml
CHANGED
@@ -13,11 +13,11 @@ train:
|
|
13 |
weight_decay: 0.0001
|
14 |
loss:
|
15 |
objective:
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
aux:
|
20 |
-
|
21 |
matcher:
|
22 |
iou: CIoU
|
23 |
topk: 10
|
|
|
13 |
weight_decay: 0.0001
|
14 |
loss:
|
15 |
objective:
|
16 |
+
BCELoss: 0.5
|
17 |
+
BoxLoss: 7.5
|
18 |
+
DFLoss: 1.5
|
19 |
aux:
|
20 |
+
0.25
|
21 |
matcher:
|
22 |
iou: CIoU
|
23 |
topk: 10
|
yolo/utils/loss.py
CHANGED
@@ -15,6 +15,7 @@ from yolo.tools.bbox_helper import (
|
|
15 |
make_anchor,
|
16 |
transform_bbox,
|
17 |
)
|
|
|
18 |
|
19 |
|
20 |
class BCELoss(nn.Module):
|
@@ -135,14 +136,10 @@ class YOLOLoss:
|
|
135 |
anchors_box = anchors_box / self.scaler[None, :, None]
|
136 |
return anchors_cls, anchors_box
|
137 |
|
138 |
-
@torch.autocast("cuda" if torch.cuda.is_available() else "cpu")
|
139 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
140 |
# Batch_Size x (Anchor + Class) x H x W
|
141 |
# TODO: check datatype, why targets has a little bit error with origin version
|
142 |
-
predicts, predicts_anc = self.parse_predicts(predicts
|
143 |
-
# TODO: Refactor this operator
|
144 |
-
# targets = self.parse_targets(targets, batch_size=predicts.size(0))
|
145 |
-
targets[:, :, 1:] = targets[:, :, 1:] * self.scale_up
|
146 |
|
147 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
148 |
# calculate loss between with instance and predict
|
@@ -160,11 +157,35 @@ class YOLOLoss:
|
|
160 |
## -- DFL -- ##
|
161 |
loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
|
162 |
|
163 |
-
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
|
167 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|
168 |
-
loss_function =
|
169 |
logger.info("β
Success load loss function")
|
170 |
return loss_function
|
|
|
15 |
make_anchor,
|
16 |
transform_bbox,
|
17 |
)
|
18 |
+
from yolo.tools.module_helper import make_chunk
|
19 |
|
20 |
|
21 |
class BCELoss(nn.Module):
|
|
|
136 |
anchors_box = anchors_box / self.scaler[None, :, None]
|
137 |
return anchors_cls, anchors_box
|
138 |
|
|
|
139 |
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
140 |
# Batch_Size x (Anchor + Class) x H x W
|
141 |
# TODO: check datatype, why targets has a little bit error with origin version
|
142 |
+
predicts, predicts_anc = self.parse_predicts(predicts)
|
|
|
|
|
|
|
143 |
|
144 |
align_targets, valid_masks = self.matcher(targets, predicts)
|
145 |
# calculate loss between with instance and predict
|
|
|
157 |
## -- DFL -- ##
|
158 |
loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
|
159 |
|
160 |
+
return loss_iou, loss_dfl, loss_cls
|
161 |
+
|
162 |
+
|
163 |
+
class DualLoss:
|
164 |
+
def __init__(self, cfg: Config) -> None:
|
165 |
+
self.loss = YOLOLoss(cfg)
|
166 |
+
self.aux_rate = cfg.hyper.train.loss.aux
|
167 |
+
|
168 |
+
self.iou_rate = cfg.hyper.train.loss.objective["BoxLoss"]
|
169 |
+
self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
|
170 |
+
self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
|
171 |
+
|
172 |
+
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tuple[Tensor]]:
|
173 |
+
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
174 |
+
|
175 |
+
# TODO: Need Refactor this region, make it flexible!
|
176 |
+
predicts = make_chunk(predicts[0], 2)
|
177 |
+
aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
|
178 |
+
main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
|
179 |
+
|
180 |
+
loss_iou = self.iou_rate * (aux_iou * self.aux_rate + main_iou)
|
181 |
+
loss_dfl = self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl)
|
182 |
+
loss_cls = self.cls_rate * (aux_cls * self.aux_rate + main_cls)
|
183 |
+
|
184 |
+
loss = (loss_iou + loss_dfl + loss_cls) / 3
|
185 |
+
return loss, (loss_iou, loss_dfl, loss_cls)
|
186 |
|
187 |
|
188 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|
189 |
+
loss_function = DualLoss(cfg)
|
190 |
logger.info("β
Success load loss function")
|
191 |
return loss_function
|