✨ [Update] back the warm up batch step
Browse files- yolo/tools/solver.py +9 -2
- yolo/utils/model_utils.py +17 -1
yolo/tools/solver.py
CHANGED
@@ -27,6 +27,8 @@ class ValidateModel(BaseModel):
|
|
27 |
else:
|
28 |
self.validation_cfg = self.cfg.task.validation
|
29 |
self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
|
|
|
|
|
30 |
|
31 |
def setup(self, stage):
|
32 |
self.vec2box = create_converter(
|
@@ -35,7 +37,7 @@ class ValidateModel(BaseModel):
|
|
35 |
self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
|
36 |
|
37 |
def val_dataloader(self):
|
38 |
-
return
|
39 |
|
40 |
def validation_step(self, batch, batch_idx):
|
41 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
@@ -68,15 +70,20 @@ class TrainModel(ValidateModel):
|
|
68 |
def __init__(self, cfg: Config):
|
69 |
super().__init__(cfg)
|
70 |
self.cfg = cfg
|
|
|
71 |
|
72 |
def setup(self, stage):
|
73 |
super().setup(stage)
|
74 |
self.loss_fn = create_loss_function(self.cfg, self.vec2box)
|
75 |
|
76 |
def train_dataloader(self):
|
77 |
-
return
|
|
|
|
|
|
|
78 |
|
79 |
def training_step(self, batch, batch_idx):
|
|
|
80 |
batch_size, images, targets, *_ = batch
|
81 |
predicts = self(images)
|
82 |
aux_predicts = self.vec2box(predicts["AUX"])
|
|
|
27 |
else:
|
28 |
self.validation_cfg = self.cfg.task.validation
|
29 |
self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
|
30 |
+
self.metric.warn_on_many_detections = False
|
31 |
+
self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
|
32 |
|
33 |
def setup(self, stage):
|
34 |
self.vec2box = create_converter(
|
|
|
37 |
self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)
|
38 |
|
39 |
def val_dataloader(self):
|
40 |
+
return self.val_loader
|
41 |
|
42 |
def validation_step(self, batch, batch_idx):
|
43 |
batch_size, images, targets, rev_tensor, img_paths = batch
|
|
|
70 |
def __init__(self, cfg: Config):
|
71 |
super().__init__(cfg)
|
72 |
self.cfg = cfg
|
73 |
+
self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)
|
74 |
|
75 |
def setup(self, stage):
|
76 |
super().setup(stage)
|
77 |
self.loss_fn = create_loss_function(self.cfg, self.vec2box)
|
78 |
|
79 |
def train_dataloader(self):
|
80 |
+
return self.train_loader
|
81 |
+
|
82 |
+
def on_train_epoch_start(self):
|
83 |
+
self.trainer.optimizers[0].next_epoch(len(self.train_loader))
|
84 |
|
85 |
def training_step(self, batch, batch_idx):
|
86 |
+
self.trainer.optimizers[0].next_batch()
|
87 |
batch_size, images, targets, *_ = batch
|
88 |
predicts = self(images)
|
89 |
aux_predicts = self.vec2box(predicts["AUX"])
|
yolo/utils/model_utils.py
CHANGED
@@ -56,8 +56,24 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
|
|
56 |
{"params": conv_params, "momentum": 0.8},
|
57 |
{"params": norm_params, "momentum": 0.8, "weight_decay": 0},
|
58 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
optimizer = optimizer_class(model_parameters, **optim_cfg.args)
|
60 |
-
|
61 |
return optimizer
|
62 |
|
63 |
|
|
|
56 |
{"params": conv_params, "momentum": 0.8},
|
57 |
{"params": norm_params, "momentum": 0.8, "weight_decay": 0},
|
58 |
]
|
59 |
+
|
60 |
+
def next_epoch(self, batch_num):
|
61 |
+
self.min_lr = self.max_lr
|
62 |
+
self.max_lr = [param["lr"] for param in self.param_groups]
|
63 |
+
self.batch_num = batch_num
|
64 |
+
self.batch_idx = 0
|
65 |
+
|
66 |
+
def next_batch(self):
|
67 |
+
self.batch_idx += 1
|
68 |
+
for lr_idx, param_group in enumerate(self.param_groups):
|
69 |
+
min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
|
70 |
+
param_group["lr"] = min_lr + (self.batch_idx) * (max_lr - min_lr) / self.batch_num
|
71 |
+
|
72 |
+
optimizer_class.next_batch = next_batch
|
73 |
+
optimizer_class.next_epoch = next_epoch
|
74 |
+
|
75 |
optimizer = optimizer_class(model_parameters, **optim_cfg.args)
|
76 |
+
optimizer.max_lr = [0.1, 0, 0]
|
77 |
return optimizer
|
78 |
|
79 |
|