🚑️ [Fix] EMA parameter broadcast; use sync_dist only during validation
Browse files
- yolo/tools/solver.py +5 -4
- yolo/utils/model_utils.py +1 -2
yolo/tools/solver.py
CHANGED
@@ -63,9 +63,11 @@ class ValidateModel(BaseModel):
|
|
63 |
def on_validation_epoch_end(self):
|
64 |
epoch_metrics = self.metric.compute()
|
65 |
del epoch_metrics["classes"]
|
66 |
-
self.log_dict(epoch_metrics, prog_bar=True, rank_zero_only=True)
|
67 |
self.log_dict(
|
68 |
-
{"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
|
|
|
|
|
69 |
)
|
70 |
self.metric.reset()
|
71 |
|
@@ -101,10 +103,9 @@ class TrainModel(ValidateModel):
|
|
101 |
prog_bar=True,
|
102 |
on_epoch=True,
|
103 |
batch_size=batch_size,
|
104 |
-
sync_dist=True,
|
105 |
rank_zero_only=True,
|
106 |
)
|
107 |
-
self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, sync_dist=True, rank_zero_only=True)
|
108 |
return loss * batch_size
|
109 |
|
110 |
def configure_optimizers(self):
|
|
|
63 |
def on_validation_epoch_end(self):
|
64 |
epoch_metrics = self.metric.compute()
|
65 |
del epoch_metrics["classes"]
|
66 |
+
self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)
|
67 |
self.log_dict(
|
68 |
+
{"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
|
69 |
+
sync_dist=True,
|
70 |
+
rank_zero_only=True,
|
71 |
)
|
72 |
self.metric.reset()
|
73 |
|
|
|
103 |
prog_bar=True,
|
104 |
on_epoch=True,
|
105 |
batch_size=batch_size,
|
|
|
106 |
rank_zero_only=True,
|
107 |
)
|
108 |
+
self.log_dict(lr_dict, prog_bar=False, logger=True, on_epoch=False, rank_zero_only=True)
|
109 |
return loss * batch_size
|
110 |
|
111 |
def configure_optimizers(self):
|
yolo/utils/model_utils.py
CHANGED
@@ -53,8 +53,7 @@ class EMA(Callback):
|
|
53 |
def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
|
54 |
for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
|
55 |
param.data.copy_(ema_param)
|
56 |
-
|
57 |
-
dist.broadcast(param, src=0)
|
58 |
|
59 |
@rank_zero_only
|
60 |
@no_grad()
|
|
|
53 |
def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
|
54 |
for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
|
55 |
param.data.copy_(ema_param)
|
56 |
+
trainer.strategy.broadcast(param)
|
|
|
57 |
|
58 |
@rank_zero_only
|
59 |
@no_grad()
|