henry000 commited on
Commit
4be6676
·
1 Parent(s): 2522f72

✨ [Add] a momentum schedule for warmup epoch

Browse files
Files changed (2) hide show
  1. yolo/tools/solver.py +3 -1
  2. yolo/utils/model_utils.py +25 -2
yolo/tools/solver.py CHANGED
@@ -84,7 +84,9 @@ class TrainModel(ValidateModel):
84
  return self.train_loader
85
 
86
  def on_train_epoch_start(self):
87
- self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
 
 
88
  self.vec2box.update(self.cfg.image_size)
89
 
90
  def training_step(self, batch, batch_idx):
 
84
  return self.train_loader
85
 
86
def on_train_epoch_start(self):
    """Hook run at the start of each training epoch.

    Advances the custom optimizer schedule (per-epoch LR / momentum warmup
    endpoints) and refreshes the Vec2Box converter for the current image size.
    """
    # Batches this rank will actually process: total batches split across ranks.
    batches_per_rank = ceil(len(self.train_loader) / self.trainer.world_size)
    self.trainer.optimizers[0].next_epoch(batches_per_rank, self.current_epoch)
    # NOTE(review): presumably image_size may change between epochs — confirm.
    self.vec2box.update(self.cfg.image_size)
91
 
92
  def training_step(self, batch, batch_idx):
yolo/utils/model_utils.py CHANGED
@@ -15,6 +15,22 @@ from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_
15
  from yolo.utils.logger import logger
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class ExponentialMovingAverage:
19
  def __init__(self, model: torch.nn.Module, decay: float):
20
  self.model = model
@@ -57,9 +73,15 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
57
  {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
58
  ]
59
 
60
- def next_epoch(self, batch_num):
61
  self.min_lr = self.max_lr
62
  self.max_lr = [param["lr"] for param in self.param_groups]
 
 
 
 
 
 
63
  self.batch_num = batch_num
64
  self.batch_idx = 0
65
 
@@ -68,7 +90,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
68
  lr_dict = dict()
69
  for lr_idx, param_group in enumerate(self.param_groups):
70
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
71
- param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
 
72
  lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
73
  return lr_dict
74
 
 
15
  from yolo.utils.logger import logger
16
 
17
 
18
def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
    """Linearly interpolate between ``start`` and ``end``.

    Parameters:
        start (float): Value at step 0.
        end (float): Value at step ``total``.
        step (int | float): Current position along the interpolation.
        total (int): Number of steps spanning the full range.

    Returns:
        float: The interpolated value.
    """
    delta = end - start
    return start + delta * step / total
32
+
33
+
34
  class ExponentialMovingAverage:
35
  def __init__(self, model: torch.nn.Module, decay: float):
36
  self.model = model
 
73
  {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
74
  ]
75
 
76
def next_epoch(self, batch_num, epoch_idx):
    """Prepare the per-batch LR/momentum interpolation for a new epoch.

    Records the endpoints that ``next_batch`` interpolates between over the
    course of the epoch.

    Parameters:
        batch_num (int): Number of batches in the upcoming epoch.
        epoch_idx (int): Zero-based index of the epoch that is starting.
    """
    # LR ramps from last epoch's value to the scheduler-assigned value.
    self.min_lr = self.max_lr
    self.max_lr = [param["lr"] for param in self.param_groups]
    # TODO: load momentum values from the config instead of fixed numbers.
    # 0.937: starting (warmup) momentum
    # 0.8  : normal momentum after warmup
    # 3    : number of warmup epochs
    # Clamp with min() so momentum ramps 0.937 -> 0.8 over the first 3 epochs
    # and then holds at 0.8. (max() would skip the ramp entirely at epoch 0
    # and overshoot below 0.8 for epochs > 3.)
    self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
    self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
    self.batch_num = batch_num
    self.batch_idx = 0
87
 
 
90
  lr_dict = dict()
91
  for lr_idx, param_group in enumerate(self.param_groups):
92
  min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
93
+ param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
94
+ param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
95
  lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
96
  return lr_dict
97