henry000 commited on
Commit
9c191b9
·
1 Parent(s): 7f8fc3e

✨ [Add] Visualize image after validation, train

Browse files
yolo/tools/solver.py CHANGED
@@ -231,7 +231,7 @@ class ModelValidator:
231
  if json_path:
232
  self.coco_gt = COCO(json_path)
233
 
234
- def solve(self, dataloader, epoch_idx=-1):
235
  # logger.info("🧪 Start Validation!")
236
  self.model.eval()
237
  predict_json, mAPs = [], defaultdict(list)
@@ -251,6 +251,7 @@ class ModelValidator:
251
 
252
  predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
253
  self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
 
254
 
255
  with open(self.json_path, "w") as f:
256
  json.dump(predict_json, f)
 
231
  if json_path:
232
  self.coco_gt = COCO(json_path)
233
 
234
+ def solve(self, dataloader, epoch_idx=1):
235
  # logger.info("🧪 Start Validation!")
236
  self.model.eval()
237
  predict_json, mAPs = [], defaultdict(list)
 
251
 
252
  predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
253
  self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
254
+ self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
255
 
256
  with open(self.json_path, "w") as f:
257
  json.dump(predict_json, f)
yolo/utils/logging_utils.py CHANGED
@@ -36,9 +36,11 @@ from rich.table import Table
36
  from torch import Tensor
37
  from torch.nn import ModuleList
38
  from torch.optim import Optimizer
 
39
 
40
  from yolo.config.config import Config, YOLOLayer
41
  from yolo.model.yolo import YOLO
 
42
  from yolo.utils.solver_utils import make_ap_table
43
 
44
 
@@ -153,6 +155,49 @@ class ProgressLogger(Progress):
153
 
154
  self.remove_task(self.batch_task)
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def start_pycocotools(self):
157
  self.batch_task = self.add_task("[green]Run pycocotools", total=1)
158
 
@@ -236,3 +281,37 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
236
  logger.opt(colors=True).info(f"📄 Created log folder: <u><fg #808080>{save_path}</></>")
237
  logger.add(save_path / "output.log", mode="w", backtrace=True, diagnose=True)
238
  return save_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  from torch import Tensor
37
  from torch.nn import ModuleList
38
  from torch.optim import Optimizer
39
+ from torchvision.transforms.functional import pil_to_tensor
40
 
41
  from yolo.config.config import Config, YOLOLayer
42
  from yolo.model.yolo import YOLO
43
+ from yolo.tools.drawer import draw_bboxes
44
  from yolo.utils.solver_utils import make_ap_table
45
 
46
 
 
155
 
156
  self.remove_task(self.batch_task)
157
 
158
+ def visualize_image(
159
+ self,
160
+ images: Optional[Tensor] = None,
161
+ ground_truth: Optional[Tensor] = None,
162
+ prediction: Optional[Union[List[Tensor], Tensor]] = None,
163
+ epoch_idx: int = 0,
164
+ ) -> None:
165
+ """
166
+ Upload the ground truth bounding boxes, predicted bounding boxes, and the original image to wandb or TensorBoard.
167
+
168
+ Args:
169
+ images (Optional[Tensor]): Tensor of images with shape (BZ, 3, 640, 640).
170
+ ground_truth (Optional[Tensor]): Ground truth bounding boxes with shape (BZ, N, 5) or (N, 5). Defaults to None.
171
+ prediction (prediction: Optional[Union[List[Tensor], Tensor]]): List of predicted bounding boxes with shape (N, 6) or (N, 6). Defaults to None.
172
+ epoch_idx (int): Current epoch index. Defaults to 0.
173
+ """
174
+ if images is not None:
175
+ images = images[0] if images.ndim == 4 else images
176
+ if self.use_wandb:
177
+ wandb.log({"Input Image": wandb.Image(images)}, step=epoch_idx)
178
+ if self.use_tensorboard:
179
+ self.tb_writer.add_image("Media/Input Image", images, 1)
180
+
181
+ if ground_truth is not None:
182
+ gt_boxes = ground_truth[0] if ground_truth.ndim == 3 else ground_truth
183
+ if self.use_wandb:
184
+ wandb.log(
185
+ {"Ground Truth": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(gt_boxes)}})},
186
+ step=epoch_idx,
187
+ )
188
+ if self.use_tensorboard:
189
+ self.tb_writer.add_image("Media/Ground Truth", pil_to_tensor(draw_bboxes(images, gt_boxes)), epoch_idx)
190
+
191
+ if prediction is not None:
192
+ pred_boxes = prediction[0] if isinstance(prediction, list) else prediction
193
+ if self.use_wandb:
194
+ wandb.log(
195
+ {"Prediction": wandb.Image(images, boxes={"predictions": {"box_data": log_bbox(pred_boxes)}})},
196
+ step=epoch_idx,
197
+ )
198
+ if self.use_tensorboard:
199
+ self.tb_writer.add_image("Media/Prediction", pil_to_tensor(draw_bboxes(images, pred_boxes)), epoch_idx)
200
+
201
  def start_pycocotools(self):
202
  self.batch_task = self.add_task("[green]Run pycocotools", total=1)
203
 
 
281
  logger.opt(colors=True).info(f"📄 Created log folder: <u><fg #808080>{save_path}</></>")
282
  logger.add(save_path / "output.log", mode="w", backtrace=True, diagnose=True)
283
  return save_path
284
+
285
+
286
+ def log_bbox(
287
+ bboxes: Tensor, class_list: Optional[List[str]] = None, image_size: Tuple[int, int] = (640, 640)
288
+ ) -> List[dict]:
289
+ """
290
+ Convert bounding boxes tensor to a list of dictionaries for logging, normalized by the image size.
291
+
292
+ Args:
293
+ bboxes (Tensor): Bounding boxes with shape (N, 5) or (N, 6), where each box is [class_id, x_min, y_min, x_max, y_max, (confidence)].
294
+ class_list (Optional[List[str]]): List of class names. Defaults to None.
295
+ image_size (Tuple[int, int]): The size of the image, used for normalization. Defaults to (640, 640).
296
+
297
+ Returns:
298
+ List[dict]: List of dictionaries containing normalized bounding box information.
299
+ """
300
+ bbox_list = []
301
+ scale_tensor = torch.Tensor([1, *image_size, *image_size]).to(bboxes.device)
302
+ normalized_bboxes = bboxes[:, :5] / scale_tensor
303
+ for bbox in normalized_bboxes:
304
+ class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
305
+ if class_id == -1:
306
+ break
307
+ bbox_entry = {
308
+ "position": {"minX": x_min, "maxX": x_max, "minY": y_min, "maxY": y_max},
309
+ "class_id": int(class_id),
310
+ }
311
+ if class_list:
312
+ bbox_entry["box_caption"] = class_list[int(class_id)]
313
+ if conf:
314
+ bbox_entry["scores"] = {"confidence": conf[0]}
315
+ bbox_list.append(bbox_entry)
316
+
317
+ return bbox_list