henry000 committed
Commit
67fbfa0
·
2 Parent(s): a80fd8c 1d404e2

🔀 [Merge] branch 'TRAIN'

Files changed (1)
  1. yolo/utils/model_utils.py +8 -8
yolo/utils/model_utils.py CHANGED
@@ -37,31 +37,31 @@ def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
 
 
 class EMA(Callback):
-    def __init__(self, decay: float = 0.9999, tau: float = 500):
+    def __init__(self, decay: float = 0.9999, tau: float = 2000):
         super().__init__()
         logger.info(":chart_with_upwards_trend: Enable Model EMA")
         self.decay = decay
         self.tau = tau
         self.step = 0
+        self.ema_state_dict = None
 
     def setup(self, trainer, pl_module, stage):
         pl_module.ema = deepcopy(pl_module.model)
-        self.ema_parameters = [param.clone().detach().to(pl_module.device) for param in pl_module.parameters()]
+        self.tau /= trainer.world_size
         for param in pl_module.ema.parameters():
             param.requires_grad = False
 
     def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
-        for param, ema_param in zip(pl_module.ema.parameters(), self.ema_parameters):
-            param.data.copy_(ema_param)
-            trainer.strategy.broadcast(param)
+        if self.ema_state_dict is None:
+            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
+        pl_module.ema.load_state_dict(self.ema_state_dict)
 
-    @rank_zero_only
     @no_grad()
     def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
         self.step += 1
         decay_factor = self.decay * (1 - exp(-self.step / self.tau))
-        for param, ema_param in zip(pl_module.parameters(), self.ema_parameters):
-            ema_param.data.copy_(lerp(param.detach(), ema_param, decay_factor))
+        for key, param in pl_module.model.state_dict().items():
+            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)
 
 
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
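
For context, the new implementation keeps the averaged weights in a plain state_dict and blends them with lerp, instead of maintaining a separate parameter list that only rank zero updates and broadcasts. Below is a minimal, self-contained sketch of that update scheme, assuming lerp(start, end, step) means start + (end - start) * step (as the hunk header's signature suggests); the SimpleEMA class and its method names are hypothetical and not part of the repository.

# Illustrative sketch only; SimpleEMA is a hypothetical stand-in for the
# Lightning callback above.
from copy import deepcopy
from math import exp

import torch
from torch import nn


def lerp(start, end, step, total: int = 1):
    # Assumed semantics of the repo's lerp helper: move from `start` toward
    # `end` by the fraction step / total.
    return start + (end - start) * (step / total)


class SimpleEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9999, tau: float = 2000):
        self.decay = decay
        self.tau = tau
        self.step = 0
        # Detached copy of the full state_dict (parameters and buffers),
        # mirroring self.ema_state_dict in the new callback.
        self.ema_state_dict = deepcopy(model.state_dict())

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        # Counterpart of on_train_batch_end: blend live weights into the EMA.
        self.step += 1
        # Warm-up: the effective decay ramps from ~0 toward `decay` over
        # roughly `tau` updates, so the EMA tracks the live weights early on.
        decay_factor = self.decay * (1 - exp(-self.step / self.tau))
        for key, param in model.state_dict().items():
            self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)

    def copy_to(self, ema_model: nn.Module) -> None:
        # Counterpart of on_validation_start: load averaged weights for eval.
        ema_model.load_state_dict(self.ema_state_dict)

With decay_factor = decay * (1 - exp(-step / tau)), the averaged weights follow the live weights closely at the start of training and approach the nominal decay after roughly tau updates. The commit's self.tau /= trainer.world_size presumably shortens that warm-up in proportion to the number of devices, since each rank now runs the update itself (the @rank_zero_only guard is removed) and sees fewer steps per epoch as batches are sharded across more GPUs.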