henry000 commited on
Commit
3441a79
·
1 Parent(s): 8b1b21f

✨ [New] validation code! run my pycocotools

Browse files
yolo/lazy.py CHANGED
@@ -9,7 +9,7 @@ sys.path.append(str(project_root))
9
  from yolo.config.config import Config
10
  from yolo.model.yolo import create_model
11
  from yolo.tools.data_loader import create_dataloader
12
- from yolo.tools.solver import ModelTester, ModelTrainer
13
  from yolo.utils.bounding_box_utils import Vec2Box
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import ProgressLogger
@@ -37,6 +37,10 @@ def main(cfg: Config):
37
  tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
 
 
 
 
40
 
41
  if __name__ == "__main__":
42
  main()
 
9
  from yolo.config.config import Config
10
  from yolo.model.yolo import create_model
11
  from yolo.tools.data_loader import create_dataloader
12
+ from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
13
  from yolo.utils.bounding_box_utils import Vec2Box
14
  from yolo.utils.deploy_utils import FastModelLoader
15
  from yolo.utils.logging_utils import ProgressLogger
 
37
  tester = ModelTester(cfg, model, vec2box, progress, device)
38
  tester.solve(dataloader)
39
 
40
+ if cfg.task.task == "validation":
41
+ valider = ModelValidator(cfg.task, model, vec2box, progress, device)
42
+ valider.solve(dataloader)
43
+
44
 
45
  if __name__ == "__main__":
46
  main()
yolo/tools/solver.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import time
3
 
@@ -15,12 +16,14 @@ from yolo.model.yolo import YOLO
15
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
16
  from yolo.tools.drawer import draw_bboxes, draw_model
17
  from yolo.tools.loss_functions import create_loss_function
18
- from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
19
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
20
  from yolo.utils.model_utils import (
21
  ExponentialMovingAverage,
 
22
  create_optimizer,
23
  create_scheduler,
 
24
  )
25
 
26
 
@@ -176,32 +179,29 @@ class ModelValidator:
176
  validation_cfg: ValidationConfig,
177
  model: YOLO,
178
  vec2box: Vec2Box,
179
- device,
180
  progress: ProgressLogger,
 
181
  ):
182
  self.model = model
183
- self.vec2box = vec2box
184
  self.device = device
185
  self.progress = progress
186
 
187
- self.nms = validation_cfg.nms
 
188
 
189
  def solve(self, dataloader):
190
  # logger.info("🧪 Start Validation!")
191
  self.model.eval()
192
- # TODO: choice mAP metrics?
193
- iou_thresholds = torch.arange(0.5, 1.0, 0.05)
194
- map_all = []
195
  self.progress.start_one_epoch(len(dataloader))
196
  for images, targets, rev_tensor, img_paths in dataloader:
197
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
198
  with torch.no_grad():
199
  predicts = self.model(images)
200
- predicts = self.vec2box(predicts["Main"])
201
- nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
202
- for idx, predict in enumerate(nms_out):
203
- map_value = calculate_map(predict, targets[idx], iou_thresholds)
204
- map_all.append(map_value[0])
205
- self.progress.one_batch(mapp=torch.Tensor(map_all).mean())
206
 
 
207
  self.progress.finish_one_epoch()
 
 
 
1
+ import json
2
  import os
3
  import time
4
 
 
16
  from yolo.tools.data_loader import StreamDataLoader, create_dataloader
17
  from yolo.tools.drawer import draw_bboxes, draw_model
18
  from yolo.tools.loss_functions import create_loss_function
19
+ from yolo.utils.bounding_box_utils import Vec2Box
20
  from yolo.utils.logging_utils import ProgressLogger, log_model_structure
21
  from yolo.utils.model_utils import (
22
  ExponentialMovingAverage,
23
+ PostProccess,
24
  create_optimizer,
25
  create_scheduler,
26
+ predicts_to_json,
27
  )
28
 
29
 
 
179
  validation_cfg: ValidationConfig,
180
  model: YOLO,
181
  vec2box: Vec2Box,
 
182
  progress: ProgressLogger,
183
+ device,
184
  ):
185
  self.model = model
 
186
  self.device = device
187
  self.progress = progress
188
 
189
+ self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
190
+ self.json_path = os.path.join(self.progress.save_path, f"predict.json")
191
 
192
  def solve(self, dataloader):
193
  # logger.info("🧪 Start Validation!")
194
  self.model.eval()
195
+ predict_json = []
 
 
196
  self.progress.start_one_epoch(len(dataloader))
197
  for images, targets, rev_tensor, img_paths in dataloader:
198
  images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
199
  with torch.no_grad():
200
  predicts = self.model(images)
201
+ predicts = self.post_proccess(predicts, rev_tensor)
202
+ self.progress.one_batch()
 
 
 
 
203
 
204
+ predict_json.extend(predicts_to_json(img_paths, predicts))
205
  self.progress.finish_one_epoch()
206
+ with open(self.json_path, "w") as f:
207
+ json.dump(predict_json, f)
yolo/utils/logging_utils.py CHANGED
@@ -70,9 +70,9 @@ class ProgressLogger:
70
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
71
  self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
72
 
73
- def one_batch(self, loss_dict: Dict[str, Tensor] = None, mapp=None):
74
  if loss_dict is None:
75
- self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{mapp:.2%}")
76
  return
77
  if self.use_wandb:
78
  for loss_name, loss_value in loss_dict.items():
 
70
  self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
71
  self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
72
 
73
+ def one_batch(self, loss_dict: Dict[str, Tensor] = None):
74
  if loss_dict is None:
75
+ self.progress.update(self.batch_task, advance=1, description=f"[green]Validating")
76
  return
77
  if self.use_wandb:
78
  for loss_name, loss_value in loss_dict.items():