henry000 commited on
Commit
b9867bb
Β·
2 Parent(s): 22ebde1 e09ff86

πŸ”€ [Merge] branch 'DDP_BUGS' of github.com:WongKinYiu/yolov9mit into DDP_BUGS

Browse files
Files changed (2) hide show
  1. yolo/tools/data_loader.py +17 -14
  2. yolo/tools/solver.py +3 -4
yolo/tools/data_loader.py CHANGED
@@ -23,6 +23,20 @@ from yolo.utils.dataset_utils import (
23
  )
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class YoloDataset(Dataset):
27
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
28
  augment_cfg = data_cfg.data_augment
@@ -32,19 +46,8 @@ class YoloDataset(Dataset):
32
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
33
  self.transform = AugmentationComposer(transforms, self.image_size)
34
  self.transform.get_more_data = self.get_more_data
35
- self.img_paths, self.bboxes = self.tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
36
-
37
- def tensorlize(self, data):
38
- img_paths, bboxes = zip(*data)
39
- max_box = max(bbox.size(0) for bbox in bboxes)
40
- padded_bbox_list = []
41
- for bbox in bboxes:
42
- padding = torch.full((max_box, 5), -1, dtype=torch.float32)
43
- padding[: bbox.size(0)] = bbox
44
- padded_bbox_list.append(padding)
45
- bboxes = torch.stack(padded_bbox_list)
46
- img_paths = np.array(img_paths)
47
- return img_paths, bboxes
48
 
49
  def load_data(self, dataset_path: Path, phase_name: str):
50
  """
@@ -147,7 +150,7 @@ class YoloDataset(Dataset):
147
  img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
148
  valid_mask = bboxes[:, 0] != -1
149
  img = Image.open(img_path).convert("RGB")
150
- return img, bboxes[valid_mask], img_path
151
 
152
  def get_more_data(self, num: int = 1):
153
  indices = torch.randint(0, len(self), (num,))
 
23
  )
24
 
25
 
26
+ def tensorlize(data):
27
+ # TODO Move Tensorlize to helper
28
+ img_paths, bboxes = zip(*data)
29
+ max_box = max(bbox.size(0) for bbox in bboxes)
30
+ padded_bbox_list = []
31
+ for bbox in bboxes:
32
+ padding = torch.full((max_box, 5), -1, dtype=torch.float32)
33
+ padding[: bbox.size(0)] = bbox
34
+ padded_bbox_list.append(padding)
35
+ bboxes = np.stack(padded_bbox_list)
36
+ img_paths = np.array(img_paths)
37
+ return img_paths, bboxes
38
+
39
+
40
  class YoloDataset(Dataset):
41
  def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
42
  augment_cfg = data_cfg.data_augment
 
46
  transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
47
  self.transform = AugmentationComposer(transforms, self.image_size)
48
  self.transform.get_more_data = self.get_more_data
49
+ img_paths, bboxes = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
50
+ self.img_paths, self.bboxes = img_paths, bboxes
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def load_data(self, dataset_path: Path, phase_name: str):
53
  """
 
150
  img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
151
  valid_mask = bboxes[:, 0] != -1
152
  img = Image.open(img_path).convert("RGB")
153
+ return img, torch.from_numpy(bboxes[valid_mask]), img_path
154
 
155
  def get_more_data(self, num: int = 1):
156
  indices = torch.randint(0, len(self), (num,))
yolo/tools/solver.py CHANGED
@@ -147,7 +147,7 @@ class ModelTrainer:
147
  self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
148
 
149
  mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
150
- if mAPs is not None and self.good_epoch(mAPs):
151
  self.save_checkpoint(epoch_idx=epoch_idx)
152
  # TODO: save model if result are better than before
153
  self.progress.finish_train()
@@ -256,9 +256,8 @@ class ModelValidator:
256
 
257
  with open(self.json_path, "w") as f:
258
  predict_json = collect_prediction(predict_json, self.progress.local_rank)
259
- if self.progress.local_rank != 0:
260
- return
261
- json.dump(predict_json, f)
262
  if hasattr(self, "coco_gt"):
263
  self.progress.start_pycocotools()
264
  result = calculate_ap(self.coco_gt, predict_json)
 
147
  self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
148
 
149
  mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
150
+ if self.good_epoch(mAPs):
151
  self.save_checkpoint(epoch_idx=epoch_idx)
152
  # TODO: save model if result are better than before
153
  self.progress.finish_train()
 
256
 
257
  with open(self.json_path, "w") as f:
258
  predict_json = collect_prediction(predict_json, self.progress.local_rank)
259
+ if self.progress.local_rank == 0:
260
+ json.dump(predict_json, f)
 
261
  if hasattr(self, "coco_gt"):
262
  self.progress.start_pycocotools()
263
  result = calculate_ap(self.coco_gt, predict_json)