henry000 committed on
Commit
e0c8580
·
1 Parent(s): 10420ed

✨ [Add] multi-GPU on validation mode

Browse files
Files changed (2) hide show
  1. yolo/tools/solver.py +3 -1
  2. yolo/utils/model_utils.py +21 -1
yolo/tools/solver.py CHANGED
@@ -26,6 +26,7 @@ from yolo.utils.logging_utils import ProgressLogger, log_model_structure
26
  from yolo.utils.model_utils import (
27
  ExponentialMovingAverage,
28
  PostProccess,
 
29
  create_optimizer,
30
  create_scheduler,
31
  predicts_to_json,
@@ -146,7 +147,7 @@ class ModelTrainer:
146
  self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
147
 
148
  mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
149
- if self.good_epoch(mAPs):
150
  self.save_checkpoint(epoch_idx=epoch_idx)
151
  # TODO: save model if result are better than before
152
  self.progress.finish_train()
@@ -254,6 +255,7 @@ class ModelValidator:
254
  self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
255
 
256
  with open(self.json_path, "w") as f:
 
257
  json.dump(predict_json, f)
258
  if hasattr(self, "coco_gt"):
259
  self.progress.start_pycocotools()
 
26
  from yolo.utils.model_utils import (
27
  ExponentialMovingAverage,
28
  PostProccess,
29
+ collect_prediction,
30
  create_optimizer,
31
  create_scheduler,
32
  predicts_to_json,
 
147
  self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
148
 
149
  mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
150
+ if mAPs is not None and self.good_epoch(mAPs):
151
  self.save_checkpoint(epoch_idx=epoch_idx)
152
  # TODO: save model if result are better than before
153
  self.progress.finish_train()
 
255
  self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
256
 
257
  with open(self.json_path, "w") as f:
258
+ predict_json = collect_prediction(predict_json, self.progress.local_rank)
259
  json.dump(predict_json, f)
260
  if hasattr(self, "coco_gt"):
261
  self.progress.start_pycocotools()
yolo/utils/model_utils.py CHANGED
@@ -130,7 +130,7 @@ class PostProccess:
130
  self.converter = converter
131
  self.nms = nms_cfg
132
 
133
- def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
134
  prediction = self.converter(predict["Main"])
135
  pred_class, _, pred_bbox = prediction[:3]
136
  pred_conf = prediction[3] if len(prediction) == 4 else None
@@ -140,6 +140,26 @@ class PostProccess:
140
  return pred_bbox
141
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def predicts_to_json(img_paths, predicts, rev_tensor):
144
  """
145
  TODO: function document
 
130
  self.converter = converter
131
  self.nms = nms_cfg
132
 
133
+ def __call__(self, predict, rev_tensor: Optional[Tensor] = None) -> List[Tensor]:
134
  prediction = self.converter(predict["Main"])
135
  pred_class, _, pred_bbox = prediction[:3]
136
  pred_conf = prediction[3] if len(prediction) == 4 else None
 
140
  return pred_bbox
141
 
142
 
143
def collect_prediction(predict_json: List, local_rank: int) -> List:
    """
    Gather the prediction lists of all distributed processes onto the main process (global rank 0).

    Args:
        predict_json (List): Predictions produced by the current process. It must be a list,
            because the gathered per-process results are flattened into a single list.
        local_rank (int): Rank of the current process on its node. Kept for backward
            compatibility; the gather destination is decided from the global rank, since
            `dst=0` in `gather_object` refers to the global rank and several nodes may each
            have a process whose local rank is 0.

    Returns:
        List: On global rank 0 of an initialized process group, the flattened predictions
            gathered from every process. In every other case, `predict_json` unchanged.
    """
    # Without an initialized process group there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return predict_json

    if dist.get_rank() == 0:
        # Only the destination rank may supply the output list.
        gathered = [None] * dist.get_world_size()
        dist.gather_object(predict_json, gathered, dst=0)
        return [item for per_rank in gathered for item in per_rank]

    # Non-destination ranks only contribute their object.
    dist.gather_object(predict_json, None, dst=0)
    return predict_json
161
+
162
+
163
  def predicts_to_json(img_paths, predicts, rev_tensor):
164
  """
165
  TODO: function document