henry000 committed on
Commit
73b88fc
·
1 Parent(s): 66f3f62

✨ [Add] saving function when higher mAP

Browse files
Files changed (2) hide show
  1. yolo/tools/solver.py +38 -12
  2. yolo/utils/logging_utils.py +4 -12
yolo/tools/solver.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import sys
4
  import time
5
  from collections import defaultdict
 
6
 
7
  import torch
8
  from loguru import logger
@@ -43,6 +44,9 @@ class ModelTrainer:
43
  self.loss_fn = create_loss_function(cfg, vec2box)
44
  self.progress = progress
45
  self.num_epochs = cfg.task.epoch
 
 
 
46
 
47
  if not progress.quite_mode:
48
  log_model_structure(model.model)
@@ -96,9 +100,12 @@ class ModelTrainer:
96
 
97
  return total_loss
98
 
99
- def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
 
 
 
100
  checkpoint = {
101
- "epoch": epoch,
102
  "model_state_dict": self.model.state_dict(),
103
  "optimizer_state_dict": self.optimizer.state_dict(),
104
  }
@@ -106,22 +113,34 @@ class ModelTrainer:
106
  self.ema.apply_shadow()
107
  checkpoint["model_state_dict_ema"] = self.model.state_dict()
108
  self.ema.restore()
109
- torch.save(checkpoint, filename)
 
 
 
 
 
 
 
 
 
 
110
 
111
  def solve(self, dataloader: DataLoader):
112
  logger.info("🚄 Start Training!")
113
  num_epochs = self.num_epochs
114
 
115
  self.progress.start_train(num_epochs)
116
- for epoch in range(num_epochs):
117
  if self.use_ddp:
118
- dataloader.sampler.set_epoch(epoch)
119
 
120
- self.progress.start_one_epoch(len(dataloader), "Train", self.optimizer, epoch)
121
  epoch_loss = self.train_one_epoch(dataloader)
122
- self.progress.finish_one_epoch(epoch_loss, epoch)
123
 
124
- self.validator.solve(self.validation_dataloader, epoch_idx=epoch)
 
 
125
  # TODO: save model if result are better than before
126
  self.progress.finish_train()
127
 
@@ -206,7 +225,7 @@ class ModelValidator:
206
  def solve(self, dataloader, epoch_idx=-1):
207
  # logger.info("🧪 Start Validation!")
208
  self.model.eval()
209
- mAPs, predict_json = [], []
210
  self.progress.start_one_epoch(len(dataloader), task="Validate")
211
  for batch_size, images, targets, rev_tensor, img_paths in dataloader:
212
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
@@ -214,14 +233,21 @@ class ModelValidator:
214
  predicts = self.model(images)
215
  predicts = self.post_proccess(predicts)
216
  for idx, predict in enumerate(predicts):
217
- mAPs.append(calculate_map(predict, targets[idx]))
218
- self.progress.one_batch(Tensor(mAPs))
 
 
 
 
219
 
220
  predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
221
- self.progress.finish_one_epoch(Tensor(mAPs), epoch_idx=epoch_idx)
 
222
  with open(self.json_path, "w") as f:
223
  json.dump(predict_json, f)
224
 
225
  self.progress.start_pycocotools()
226
  result = calculate_ap(self.coco_gt, predict_json)
227
  self.progress.finish_pycocotools(result, epoch_idx)
 
 
 
3
  import sys
4
  import time
5
  from collections import defaultdict
6
+ from typing import Dict, Optional
7
 
8
  import torch
9
  from loguru import logger
 
44
  self.loss_fn = create_loss_function(cfg, vec2box)
45
  self.progress = progress
46
  self.num_epochs = cfg.task.epoch
47
+ self.mAPs_dict = defaultdict(list)
48
+
49
+ os.makedirs(os.path.join(self.progress.save_path, "weights"), exist_ok=True)
50
 
51
  if not progress.quite_mode:
52
  log_model_structure(model.model)
 
100
 
101
  return total_loss
102
 
103
+ def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
104
+ file_name = file_name or f"E{epoch_idx:03d}.pt"
105
+ file_path = os.path.join(self.progress.save_path, "weights", file_name)
106
+
107
  checkpoint = {
108
+ "epoch": epoch_idx,
109
  "model_state_dict": self.model.state_dict(),
110
  "optimizer_state_dict": self.optimizer.state_dict(),
111
  }
 
113
  self.ema.apply_shadow()
114
  checkpoint["model_state_dict_ema"] = self.model.state_dict()
115
  self.ema.restore()
116
+
117
+ print(f"💾 success save at {file_path}")
118
+ torch.save(checkpoint, file_path)
119
+
120
+ def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
121
+ save_flag = True
122
+ for mAP_key, mAP_val in mAPs.items():
123
+ self.mAPs_dict[mAP_key].append(mAP_val)
124
+ if mAP_val < max(self.mAPs_dict[mAP_key]):
125
+ save_flag = False
126
+ return save_flag
127
 
128
  def solve(self, dataloader: DataLoader):
129
  logger.info("🚄 Start Training!")
130
  num_epochs = self.num_epochs
131
 
132
  self.progress.start_train(num_epochs)
133
+ for epoch_idx in range(num_epochs):
134
  if self.use_ddp:
135
+ dataloader.sampler.set_epoch(epoch_idx)
136
 
137
+ self.progress.start_one_epoch(len(dataloader), "Train", self.optimizer, epoch_idx)
138
  epoch_loss = self.train_one_epoch(dataloader)
139
+ self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
140
 
141
+ mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
142
+ if self.good_epoch(mAPs):
143
+ self.save_checkpoint(epoch_idx=epoch_idx)
144
  # TODO: save model if result are better than before
145
  self.progress.finish_train()
146
 
 
225
  def solve(self, dataloader, epoch_idx=-1):
226
  # logger.info("🧪 Start Validation!")
227
  self.model.eval()
228
+ predict_json, mAPs = [], defaultdict(list)
229
  self.progress.start_one_epoch(len(dataloader), task="Validate")
230
  for batch_size, images, targets, rev_tensor, img_paths in dataloader:
231
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
 
233
  predicts = self.model(images)
234
  predicts = self.post_proccess(predicts)
235
  for idx, predict in enumerate(predicts):
236
+ mAP = calculate_map(predict, targets[idx])
237
+ for mAP_key, mAP_val in mAP.items():
238
+ mAPs[mAP_key].append(mAP_val)
239
+
240
+ avg_mAPs = {key: torch.mean(torch.stack(val)) for key, val in mAPs.items()}
241
+ self.progress.one_batch(avg_mAPs)
242
 
243
  predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
244
+ self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
245
+
246
  with open(self.json_path, "w") as f:
247
  json.dump(predict_json, f)
248
 
249
  self.progress.start_pycocotools()
250
  result = calculate_ap(self.coco_gt, predict_json)
251
  self.progress.finish_pycocotools(result, epoch_idx)
252
+
253
+ return avg_mAPs
yolo/utils/logging_utils.py CHANGED
@@ -100,11 +100,6 @@ class ProgressLogger(Progress):
100
  batch_descript = "|"
101
  if self.task == "Train":
102
  self.update(self.task_epoch, advance=1 / self.num_batches)
103
- elif self.task == "Validate":
104
- batch_info = {
105
- "mAP.5": batch_info.mean(dim=0)[0],
106
- "mAP.5:.95": batch_info.mean(dim=0)[1],
107
- }
108
  for info_name, info_val in batch_info.items():
109
  epoch_descript += f"{info_name: ^9}|"
110
  batch_descript += f" {info_val:2.2f} |"
@@ -114,19 +109,16 @@ class ProgressLogger(Progress):
114
 
115
  def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
116
  if self.task == "Train":
117
- for loss_name in batch_info.keys():
118
- batch_info["Loss/" + loss_name] = batch_info.pop(loss_name)
119
  elif self.task == "Validate":
120
- batch_info = {
121
- "Metrics/mAP.5": batch_info.mean(dim=0)[0],
122
- "Metrics/mAP.5:.95": batch_info.mean(dim=0)[1],
123
- }
124
  if self.use_wandb:
125
  self.wandb.log(batch_info, step=epoch_idx)
126
  self.remove_task(self.batch_task)
127
 
128
  def start_pycocotools(self):
129
- self.batch_task = self.add_task("[green] run pycocotools", total=1)
130
 
131
  def finish_pycocotools(self, result, epoch_idx=-1):
132
  ap_table, ap_main = make_ap_table(result, self.ap_past_list, epoch_idx)
 
100
  batch_descript = "|"
101
  if self.task == "Train":
102
  self.update(self.task_epoch, advance=1 / self.num_batches)
 
 
 
 
 
103
  for info_name, info_val in batch_info.items():
104
  epoch_descript += f"{info_name: ^9}|"
105
  batch_descript += f" {info_val:2.2f} |"
 
109
 
110
  def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
111
  if self.task == "Train":
112
+ prefix = "Loss/"
 
113
  elif self.task == "Validate":
114
+ prefix = "Metrics/"
115
+ batch_info = {f"{prefix}{key}": value for key, value in batch_info.items()}
 
 
116
  if self.use_wandb:
117
  self.wandb.log(batch_info, step=epoch_idx)
118
  self.remove_task(self.batch_task)
119
 
120
  def start_pycocotools(self):
121
+ self.batch_task = self.add_task("[green]Run pycocotools", total=1)
122
 
123
  def finish_pycocotools(self, result, epoch_idx=-1):
124
  ap_table, ap_main = make_ap_table(result, self.ap_past_list, epoch_idx)