✨ [Add] mutli-GPU on validation mode
Browse files- yolo/tools/solver.py +3 -1
- yolo/utils/model_utils.py +21 -1
yolo/tools/solver.py
CHANGED
@@ -26,6 +26,7 @@ from yolo.utils.logging_utils import ProgressLogger, log_model_structure
|
|
26 |
from yolo.utils.model_utils import (
|
27 |
ExponentialMovingAverage,
|
28 |
PostProccess,
|
|
|
29 |
create_optimizer,
|
30 |
create_scheduler,
|
31 |
predicts_to_json,
|
@@ -146,7 +147,7 @@ class ModelTrainer:
|
|
146 |
self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
|
147 |
|
148 |
mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
|
149 |
-
if self.good_epoch(mAPs):
|
150 |
self.save_checkpoint(epoch_idx=epoch_idx)
|
151 |
# TODO: save model if result are better than before
|
152 |
self.progress.finish_train()
|
@@ -254,6 +255,7 @@ class ModelValidator:
|
|
254 |
self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
|
255 |
|
256 |
with open(self.json_path, "w") as f:
|
|
|
257 |
json.dump(predict_json, f)
|
258 |
if hasattr(self, "coco_gt"):
|
259 |
self.progress.start_pycocotools()
|
|
|
26 |
from yolo.utils.model_utils import (
|
27 |
ExponentialMovingAverage,
|
28 |
PostProccess,
|
29 |
+
collect_prediction,
|
30 |
create_optimizer,
|
31 |
create_scheduler,
|
32 |
predicts_to_json,
|
|
|
147 |
self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
|
148 |
|
149 |
mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
|
150 |
+
if mAPs is not None and self.good_epoch(mAPs):
|
151 |
self.save_checkpoint(epoch_idx=epoch_idx)
|
152 |
# TODO: save model if result are better than before
|
153 |
self.progress.finish_train()
|
|
|
255 |
self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
|
256 |
|
257 |
with open(self.json_path, "w") as f:
|
258 |
+
predict_json = collect_prediction(predict_json, self.progress.local_rank)
|
259 |
json.dump(predict_json, f)
|
260 |
if hasattr(self, "coco_gt"):
|
261 |
self.progress.start_pycocotools()
|
yolo/utils/model_utils.py
CHANGED
@@ -130,7 +130,7 @@ class PostProccess:
|
|
130 |
self.converter = converter
|
131 |
self.nms = nms_cfg
|
132 |
|
133 |
-
def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
|
134 |
prediction = self.converter(predict["Main"])
|
135 |
pred_class, _, pred_bbox = prediction[:3]
|
136 |
pred_conf = prediction[3] if len(prediction) == 4 else None
|
@@ -140,6 +140,26 @@ class PostProccess:
|
|
140 |
return pred_bbox
|
141 |
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
def predicts_to_json(img_paths, predicts, rev_tensor):
|
144 |
"""
|
145 |
TODO: function document
|
|
|
130 |
self.converter = converter
|
131 |
self.nms = nms_cfg
|
132 |
|
133 |
+
def __call__(self, predict, rev_tensor: Optional[Tensor] = None) -> List[Tensor]:
|
134 |
prediction = self.converter(predict["Main"])
|
135 |
pred_class, _, pred_bbox = prediction[:3]
|
136 |
pred_conf = prediction[3] if len(prediction) == 4 else None
|
|
|
140 |
return pred_bbox
|
141 |
|
142 |
|
143 |
+
def collect_prediction(predict_json: List, local_rank: int) -> List:
|
144 |
+
"""
|
145 |
+
Collects predictions from all distributed processes and gathers them on the main process (rank 0).
|
146 |
+
|
147 |
+
Args:
|
148 |
+
predict_json (List): The prediction data (can be of any type) generated by the current process.
|
149 |
+
local_rank (int): The rank of the current process. Typically, rank 0 is the main process.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
List: The combined list of predictions from all processes if on rank 0, otherwise predict_json.
|
153 |
+
"""
|
154 |
+
if dist.is_initialized() and local_rank == 0:
|
155 |
+
all_predictions = [None for _ in range(dist.get_world_size())]
|
156 |
+
dist.gather_object(predict_json, all_predictions, dst=0)
|
157 |
+
predict_json = [item for sublist in all_predictions for item in sublist]
|
158 |
+
elif dist.is_initialized():
|
159 |
+
dist.gather_object(predict_json, None, dst=0)
|
160 |
+
return predict_json
|
161 |
+
|
162 |
+
|
163 |
def predicts_to_json(img_paths, predicts, rev_tensor):
|
164 |
"""
|
165 |
TODO: function document
|