🔨 [Update] Loss function and get_loss_func
Browse files- yolo/tools/trainer.py +1 -1
- yolo/utils/loss.py +8 -6
yolo/tools/trainer.py
CHANGED
@@ -18,7 +18,7 @@ class Trainer:
|
|
18 |
self.device = device
|
19 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
20 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
21 |
-
self.loss_fn = get_loss_function()
|
22 |
|
23 |
if train_cfg.ema.get("enabled", False):
|
24 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
|
|
18 |
self.device = device
|
19 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
20 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
21 |
+
self.loss_fn = get_loss_function(cfg)
|
22 |
|
23 |
if train_cfg.ema.get("enabled", False):
|
24 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
yolo/utils/loss.py
CHANGED
@@ -17,10 +17,6 @@ from yolo.tools.bbox_helper import (
|
|
17 |
)
|
18 |
|
19 |
|
20 |
-
def get_loss_function(*args, **kwargs):
|
21 |
-
raise NotImplementedError
|
22 |
-
|
23 |
-
|
24 |
class BCELoss(nn.Module):
|
25 |
def __init__(self) -> None:
|
26 |
super().__init__()
|
@@ -162,5 +158,11 @@ class YOLOLoss:
|
|
162 |
## -- DFL -- ##
|
163 |
loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
|
164 |
|
165 |
-
|
166 |
-
return loss_iou, loss_dfl, loss_cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
20 |
class BCELoss(nn.Module):
|
21 |
def __init__(self) -> None:
|
22 |
super().__init__()
|
|
|
158 |
## -- DFL -- ##
|
159 |
loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
|
160 |
|
161 |
+
loss_sum = loss_iou * 0.5 + loss_dfl * 1.5 + loss_cls * 0.5
|
162 |
+
return loss_sum, (loss_iou, loss_dfl, loss_cls)
|
163 |
+
|
164 |
+
|
165 |
+
def get_loss_function(cfg: Config) -> YOLOLoss:
|
166 |
+
loss_function = YOLOLoss(cfg)
|
167 |
+
logger.info("✅ Success load loss function")
|
168 |
+
return loss_function
|