✨ [Add] a momentum schedule for warmup epochs
Files changed:
- yolo/tools/solver.py +3 -1
- yolo/utils/model_utils.py +25 -2
yolo/tools/solver.py

```diff
@@ -84,7 +84,9 @@ class TrainModel(ValidateModel):
         return self.train_loader
 
     def on_train_epoch_start(self):
-        self.trainer.optimizers[0].next_epoch(ceil(len(self.train_loader) / self.trainer.world_size))
+        self.trainer.optimizers[0].next_epoch(
+            ceil(len(self.train_loader) / self.trainer.world_size), self.current_epoch
+        )
         self.vec2box.update(self.cfg.image_size)
 
     def training_step(self, batch, batch_idx):
```
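`next_epoch` needs the number of optimizer steps each device will take in the coming epoch, hence the division by `world_size`; `self.current_epoch` tells it where it stands in the momentum warmup. A quick illustration with made-up numbers (not from the repo):

```python
from math import ceil

total_batches = 1000  # stand-in for len(self.train_loader)
world_size = 4        # stand-in for self.trainer.world_size

batches_per_rank = ceil(total_batches / world_size)  # 250 steps per device
# next_epoch(250, current_epoch) then spreads the per-epoch LR/momentum
# ramp evenly over the 250 steps each device actually performs.
```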
yolo/utils/model_utils.py

```diff
@@ -15,6 +15,22 @@ from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_
 from yolo.utils.logger import logger
 
 
+def lerp(start: float, end: float, step: Union[int, float], total: int = 1):
+    """
+    Linearly interpolates between start and end values.
+
+    Parameters:
+        start (float): The starting value.
+        end (float): The ending value.
+        step (int): The current step in the interpolation process.
+        total (int): The total number of steps.
+
+    Returns:
+        float: The interpolated value.
+    """
+    return start + (end - start) * step / total
+
+
 class ExponentialMovingAverage:
     def __init__(self, model: torch.nn.Module, decay: float):
         self.model = model
```
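`lerp` maps the progress `step / total` onto the interval from `start` to `end`; with the default `total=1`, `step` acts as a plain fraction. A small sanity check, worked out by hand:

```python
def lerp(start, end, step, total=1):  # same formula as the helper above
    return start + (end - start) * step / total

assert lerp(0.0, 10.0, 5, 10) == 5.0  # halfway between 0 and 10
print(lerp(0.937, 0.8, 0, 3))         # 0.937  -> momentum at the start of warmup
print(lerp(0.937, 0.8, 1, 3))         # ~0.891 -> one warmup epoch in
print(lerp(0.937, 0.8, 3, 3))         # ~0.800 -> warmup finished
```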
```diff
@@ -57,9 +73,15 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
     ]
 
-    def next_epoch(self, batch_num):
+    def next_epoch(self, batch_num, epoch_idx):
         self.min_lr = self.max_lr
         self.max_lr = [param["lr"] for param in self.param_groups]
+        # TODO: load momentum from config instead of a fixed number
+        # 0.937: Start Momentum
+        # 0.8  : Normal Momentum
+        # 3    : The warmup epoch num
+        self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+        self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
         self.batch_num = batch_num
         self.batch_idx = 0
```
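Replaying the schedule makes the effect concrete: the per-epoch momentum endpoints move linearly from 0.937 toward 0.8 over the first 3 epochs and then hold. A standalone sketch (not the repo's code) of the endpoints each epoch would see:

```python
def lerp(start, end, step, total=1):
    return start + (end - start) * step / total

START_MOM, NORMAL_MOM, WARMUP_EPOCHS = 0.937, 0.8, 3  # the numbers hard-coded above

for epoch in range(5):
    min_mom = lerp(START_MOM, NORMAL_MOM, min(epoch, WARMUP_EPOCHS), WARMUP_EPOCHS)
    max_mom = lerp(START_MOM, NORMAL_MOM, min(epoch + 1, WARMUP_EPOCHS), WARMUP_EPOCHS)
    print(f"epoch {epoch}: momentum ramps {min_mom:.4f} -> {max_mom:.4f}")
# epoch 0: momentum ramps 0.9370 -> 0.8913
# epoch 1: momentum ramps 0.8913 -> 0.8457
# epoch 2: momentum ramps 0.8457 -> 0.8000
# epoch 3: momentum ramps 0.8000 -> 0.8000
# epoch 4: momentum ramps 0.8000 -> 0.8000
```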
```diff
@@ -68,7 +90,8 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
         lr_dict = dict()
         for lr_idx, param_group in enumerate(self.param_groups):
             min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
-            param_group["lr"] = min_lr + (max_lr - min_lr) * self.batch_idx / self.batch_num
+            param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
+            param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
             lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
         return lr_dict
 
```
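Within an epoch, each optimizer step advances `batch_idx`, so learning rate and momentum both move linearly between their per-epoch endpoints. A sketch of one short epoch, with hypothetical endpoint values:

```python
def lerp(start, end, step, total=1):
    return start + (end - start) * step / total

# Hypothetical endpoints for one param group during a single epoch:
min_lr, max_lr = 0.0, 0.008       # LR still ramping up
min_mom, max_mom = 0.937, 0.893   # roughly the epoch-0 momentum endpoints
batch_num = 4                     # pretend the epoch is only 4 batches long

for batch_idx in range(batch_num):
    lr = lerp(min_lr, max_lr, batch_idx, batch_num)
    momentum = lerp(min_mom, max_mom, batch_idx, batch_num)
    print(f"step {batch_idx}: lr={lr:.4f}, momentum={momentum:.3f}")
# prints approximately:
# step 0: lr=0.0000, momentum=0.937
# step 1: lr=0.0020, momentum=0.926
# step 2: lr=0.0040, momentum=0.915
# step 3: lr=0.0060, momentum=0.904
```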