henry000 commited on
Commit
584d5bd
·
1 Parent(s): 669657d

🔨 [Update] Loss function and get_loss_func

Browse files
Files changed (2) hide show
  1. yolo/tools/trainer.py +1 -1
  2. yolo/utils/loss.py +8 -6
yolo/tools/trainer.py CHANGED
@@ -18,7 +18,7 @@ class Trainer:
18
  self.device = device
19
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
20
  self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
21
- self.loss_fn = get_loss_function()
22
 
23
  if train_cfg.ema.get("enabled", False):
24
  self.ema = EMA(model, decay=train_cfg.ema.decay)
 
18
  self.device = device
19
  self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
20
  self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
21
+ self.loss_fn = get_loss_function(cfg)
22
 
23
  if train_cfg.ema.get("enabled", False):
24
  self.ema = EMA(model, decay=train_cfg.ema.decay)
yolo/utils/loss.py CHANGED
@@ -17,10 +17,6 @@ from yolo.tools.bbox_helper import (
17
  )
18
 
19
 
20
- def get_loss_function(*args, **kwargs):
21
- raise NotImplementedError
22
-
23
-
24
  class BCELoss(nn.Module):
25
  def __init__(self) -> None:
26
  super().__init__()
@@ -162,5 +158,11 @@ class YOLOLoss:
162
  ## -- DFL -- ##
163
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
164
 
165
- logger.info("Loss IoU: {:.5f}, DFL: {:.5f}, CLS: {:.5f}", loss_iou, loss_dfl, loss_cls)
166
- return loss_iou, loss_dfl, loss_cls
 
 
 
 
 
 
 
17
  )
18
 
19
 
 
 
 
 
20
  class BCELoss(nn.Module):
21
  def __init__(self) -> None:
22
  super().__init__()
 
158
  ## -- DFL -- ##
159
  loss_dfl = self.dfl(predicts_anc, targets_bbox, valid_masks, box_norm, cls_norm)
160
 
161
+ loss_sum = loss_iou * 0.5 + loss_dfl * 1.5 + loss_cls * 0.5
162
+ return loss_sum, (loss_iou, loss_dfl, loss_cls)
163
+
164
+
165
+ def get_loss_function(cfg: Config) -> YOLOLoss:
166
+ loss_function = YOLOLoss(cfg)
167
+ logger.info("✅ Success load loss function")
168
+ return loss_function