[Fix] Loss scale, scale back with batch_size
Browse files
Changed files:
- yolo/tools/loss_functions.py (+5 -6)
- yolo/tools/solver.py (+3 -3)
yolo/tools/loss_functions.py
CHANGED
@@ -124,12 +124,11 @@ class DualLoss:
|
|
124 |
aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
|
125 |
main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
loss_sum = sum(list(loss_dict.values())) / len(loss_dict)
|
133 |
return loss_sum, loss_dict
|
134 |
|
135 |
|
|
|
124 |
aux_iou, aux_dfl, aux_cls = self.loss(aux_predicts, targets)
|
125 |
main_iou, main_dfl, main_cls = self.loss(main_predicts, targets)
|
126 |
|
127 |
+
BoxLoss = self.iou_rate * (aux_iou * self.aux_rate + main_iou)
|
128 |
+
DFLoss = self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl)
|
129 |
+
BCELoss = self.cls_rate * (aux_cls * self.aux_rate + main_cls)
|
130 |
+
loss_sum = (BoxLoss + DFLoss + BCELoss) / 3
|
131 |
+
loss_dict = dict(BoxLoss=BoxLoss.detach(), DFLoss=DFLoss.detach(), BCELoss=BCELoss.detach())
|
|
|
132 |
return loss_sum, loss_dict
|
133 |
|
134 |
|
yolo/tools/solver.py
CHANGED
@@ -66,7 +66,7 @@ class ModelTrainer:
|
|
66 |
self.ema = None
|
67 |
self.scaler = GradScaler()
|
68 |
|
69 |
-
def train_one_batch(self, images: Tensor, targets: Tensor):
|
70 |
images, targets = images.to(self.device), targets.to(self.device)
|
71 |
self.optimizer.zero_grad()
|
72 |
|
@@ -75,7 +75,7 @@ class ModelTrainer:
|
|
75 |
aux_predicts = self.vec2box(predicts["AUX"])
|
76 |
main_predicts = self.vec2box(predicts["Main"])
|
77 |
loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
|
78 |
-
|
79 |
self.scaler.scale(loss).backward()
|
80 |
self.scaler.unscale_(self.optimizer)
|
81 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
@@ -91,7 +91,7 @@ class ModelTrainer:
|
|
91 |
self.optimizer.next_epoch(len(dataloader))
|
92 |
for batch_size, images, targets, *_ in dataloader:
|
93 |
self.optimizer.next_batch()
|
94 |
-
loss_each = self.train_one_batch(images, targets)
|
95 |
|
96 |
for loss_name, loss_val in loss_each.items():
|
97 |
if self.use_ddp: # collecting loss for each batch
|
|
|
66 |
self.ema = None
|
67 |
self.scaler = GradScaler()
|
68 |
|
69 |
+
def train_one_batch(self, images: Tensor, targets: Tensor, batch_size: int):
|
70 |
images, targets = images.to(self.device), targets.to(self.device)
|
71 |
self.optimizer.zero_grad()
|
72 |
|
|
|
75 |
aux_predicts = self.vec2box(predicts["AUX"])
|
76 |
main_predicts = self.vec2box(predicts["Main"])
|
77 |
loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
|
78 |
+
loss *= batch_size
|
79 |
self.scaler.scale(loss).backward()
|
80 |
self.scaler.unscale_(self.optimizer)
|
81 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
|
|
91 |
self.optimizer.next_epoch(len(dataloader))
|
92 |
for batch_size, images, targets, *_ in dataloader:
|
93 |
self.optimizer.next_batch()
|
94 |
+
loss_each = self.train_one_batch(images, targets, batch_size)
|
95 |
|
96 |
for loss_name, loss_val in loss_each.items():
|
97 |
if self.use_ddp: # collecting loss for each batch
|