Merge branch 'TRAIN'
yolo/utils/model_utils.py (changed)
@@ -37,31 +37,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
 
 
 class EMA(Callback):
-    def __init__(self, decay: float = 0.9999, tau: float = …
+    def __init__(self, decay: float = 0.9999, tau: float = 2000):
         super().__init__()
         logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
         self.tau = tau
         self.step = 0
+        self.ema_state_dict = None
 
     def setup(self, trainer, pl_module, stage):
         pl_module.ema = deepcopy(pl_module.model)
-        self.…
+        self.tau /= trainer.world_size
         for param in pl_module.ema.parameters():
             param.requires_grad = False
 
     def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
-        …
-        …
-        …
+        if self.ema_state_dict is None:
+            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
+        pl_module.ema.load_state_dict(self.ema_state_dict)
 
-    @rank_zero_only
     @no_grad()
     def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
         self.step += 1
         decay_factor = self.decay * (1 - exp(-self.step / self.tau))
-        for …
-        …
+        for key, param in pl_module.model.state_dict().items():
+            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
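
In plain terms, the new on_train_batch_end keeps an exponential moving average of the model's state dict, with a warmup: decay_factor = decay * (1 - exp(-step / tau)) starts near 0 and approaches decay, so the average tracks the raw weights closely at first and then moves more and more slowly. A minimal standalone sketch of that update, assuming lerp() is plain linear interpolation (start + (end - start) * step / total, consistent with the signature in the hunk header) and using a bare nn.Module in place of the LightningModule:

    # Sketch of the EMA update rule from the diff; lerp() and the toy
    # model are assumptions, not taken from the patched file.
    from copy import deepcopy
    from math import exp

    import torch
    from torch import nn


    def lerp(start: torch.Tensor, end: torch.Tensor, step: float, total: int = 1) -> torch.Tensor:
        # Plain linear interpolation, matching the signature in the hunk header.
        return start + (end - start) * step / total


    model = nn.Linear(4, 2)
    ema_state_dict = deepcopy(model.state_dict())

    decay, tau = 0.9999, 2000.0
    for step in range(1, 101):
        # ... one optimizer step on `model` would happen here ...
        # Warmup: decay_factor ramps from ~0 toward `decay`, so early EMA
        # snapshots follow the raw weights closely, later ones move slowly.
        decay_factor = decay * (1 - exp(-step / tau))
        with torch.no_grad():
            for key, param in model.state_dict().items():
                # Keep `decay_factor` of the old average, blend in the rest
                # from the live weights.
                ema_state_dict[key] = lerp(param.detach(), ema_state_dict[key], decay_factor)

At step = tau the factor has only reached about 63% of decay (1 - e^-1 ≈ 0.632), so the warmup spans a few multiples of tau; the new self.tau /= trainer.world_size in setup() shortens that horizon proportionally under multi-GPU training, presumably so the warmup covers roughly the same number of samples regardless of world size.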
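
For context, the callback registers on the Trainer like any other Lightning callback; on_validation_start then loads the EMA snapshot into pl_module.ema so validation runs against the averaged weights, lazily initializing the snapshot from the live weights the first time. A hypothetical usage sketch, assuming a lightning >= 2.0 install and that EMA is importable from the patched module (neither is shown in the diff):

    # Hypothetical wiring of the EMA callback into a Lightning Trainer.
    from lightning import Trainer

    from yolo.utils.model_utils import EMA

    trainer = Trainer(max_epochs=100, callbacks=[EMA(decay=0.9999, tau=2000)])
    # trainer.fit(model_module, datamodule=data_module)  # placeholders

Note the dropped @rank_zero_only: every rank now applies the EMA update, presumably so ema_state_dict is populated (and, with DDP-synchronized weights, identical) on all processes rather than existing only on rank 0.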