henry000 commited on
Commit
604c897
·
1 Parent(s): 32405d5

✨ [Update] back the warm up batch step

Browse files
Files changed (2) hide show
  1. yolo/tools/solver.py +9 -2
  2. yolo/utils/model_utils.py +17 -1
yolo/tools/solver.py CHANGED
@@ -27,6 +27,8 @@ class ValidateModel(BaseModel):
27
  else:
28
  self.validation_cfg = self.cfg.task.validation
29
  self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
 
 
30
 
31
  def setup(self, stage):
32
  self.vec2box = create_converter(
@@ -35,7 +37,7 @@ class ValidateModel(BaseModel):
35
  self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
36
 
37
  def val_dataloader(self):
38
- return create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
39
 
40
  def validation_step(self, batch, batch_idx):
41
  batch_size, images, targets, rev_tensor, img_paths = batch
@@ -68,15 +70,20 @@ class TrainModel(ValidateModel):
68
  def __init__(self, cfg: Config):
69
  super().__init__(cfg)
70
  self.cfg = cfg
 
71
 
72
  def setup(self, stage):
73
  super().setup(stage)
74
  self.loss_fn = create_loss_function(self.cfg, self.vec2box)
75
 
76
  def train_dataloader(self):
77
- return create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
 
 
 
78
 
79
  def training_step(self, batch, batch_idx):
 
80
  batch_size, images, targets, *_ = batch
81
  predicts = self(images)
82
  aux_predicts = self.vec2box(predicts["AUX"])
 
27
  else:
28
  self.validation_cfg = self.cfg.task.validation
29
  self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
30
+ self.metric.warn_on_many_detections = False
31
+ self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
32
 
33
  def setup(self, stage):
34
  self.vec2box = create_converter(
 
37
  self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
38
 
39
  def val_dataloader(self):
40
+ return self.val_loader
41
 
42
  def validation_step(self, batch, batch_idx):
43
  batch_size, images, targets, rev_tensor, img_paths = batch
 
70
  def __init__(self, cfg: Config):
71
  super().__init__(cfg)
72
  self.cfg = cfg
73
+ self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
74
 
75
  def setup(self, stage):
76
  super().setup(stage)
77
  self.loss_fn = create_loss_function(self.cfg, self.vec2box)
78
 
79
  def train_dataloader(self):
80
+ return self.train_loader
81
+
82
+ def on_train_epoch_start(self):
83
+ self.trainer.optimizers[0].next_epoch(len(self.train_loader))
84
 
85
  def training_step(self, batch, batch_idx):
86
+ self.trainer.optimizers[0].next_batch()
87
  batch_size, images, targets, *_ = batch
88
  predicts = self(images)
89
  aux_predicts = self.vec2box(predicts["AUX"])
yolo/utils/model_utils.py CHANGED
@@ -56,8 +56,24 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
56
  {"params": conv_params, "momentum": 0.8},
57
  {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
58
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
60
- # TODO: implement batch lr schedular when warm up
61
  return optimizer
62
 
63
 
 
56
  {"params": conv_params, "momentum": 0.8},
57
  {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
58
  ]
59
+
60
+ def next_epoch(self, batch_num):
61
+ self.min_lr = self.max_lr
62
+ self.max_lr = [param["lr"] for param in self.param_groups]
63
+ self.batch_num = batch_num
64
+ self.batch_idx = 0
65
+
66
+ def next_batch(self):
67
+ self.batch_idx += 1
68
+ for lr_idx, param_group in enumerate(self.param_groups):
69
+ min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
70
+ param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
71
+
72
+ optimizer_class.next_batch = next_batch
73
+ optimizer_class.next_epoch = next_epoch
74
+
75
  optimizer = optimizer_class(model_parameters, **optim_cfg.args)
76
+ optimizer.max_lr = [0.1, 0, 0]
77
  return optimizer
78
 
79