henry000 commited on
Commit
635f41a
·
1 Parent(s): 6228522

:bug: [Fix] a bug with drawing picture after merge

Browse files
yolo/config/task/inference.yaml CHANGED
@@ -8,4 +8,4 @@ data:
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
- save_predict: true
 
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
+ # save_predict: True
yolo/tools/drawer.py CHANGED
@@ -13,8 +13,6 @@ def draw_bboxes(
13
  img: Union[Image.Image, torch.Tensor],
14
  bboxes: List[List[Union[int, float]]],
15
  *,
16
- save_path: str = "",
17
- save_name: str = "visualize.png",
18
  idx2label: Optional[list],
19
  ):
20
  """
@@ -114,6 +112,6 @@ def draw_model(*, model_cfg=None, model=None, v7_base=False):
114
  dot.edge(str(idx), str(jdx))
115
  try:
116
  dot.render("Model-arch", format="png", cleanup=True)
 
117
  except:
118
- logger.info("Warning: Could not find graphviz backend, continue without drawing the model architecture")
119
- logger.info("🎨 Drawing Model Architecture at Model-arch.png")
 
13
  img: Union[Image.Image, torch.Tensor],
14
  bboxes: List[List[Union[int, float]]],
15
  *,
 
 
16
  idx2label: Optional[list],
17
  ):
18
  """
 
112
  dot.edge(str(idx), str(jdx))
113
  try:
114
  dot.render("Model-arch", format="png", cleanup=True)
115
+ logger.info("🎨 Drawing Model Architecture at Model-arch.png")
116
  except:
117
+ logger.warning("⚠️ Could not find graphviz backend, continue without drawing the model architecture")
 
yolo/tools/solver.py CHANGED
@@ -108,7 +108,8 @@ class ModelTester:
108
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
109
 
110
  self.nms = cfg.task.nms
111
- self.save_path = save_path if getattr(cfg.task, "save_predict", True) else None
 
112
  self.idx2label = cfg.class_list
113
 
114
  def solve(self, dataloader: StreamDataLoader):
@@ -124,27 +125,23 @@ class ModelTester:
124
  images = images.to(self.device)
125
  with torch.no_grad():
126
  predicts = self.model(images)
127
- predicts = self.vec2box(predicts["Main"])
128
  nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
129
- draw_bboxes(
130
- images[0],
131
- nms_out[0],
132
- save_path=self.save_path,
133
- save_name=f"frame{idx:03d}.png",
134
- idx2label=self.idx2label,
135
- )
136
- logger.info(f"img size: {img.shape}")
137
- if self.save_path is not None:
138
- save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
139
- img.save(save_image_path)
140
- logger.info(f"💾 Saved visualize image at {save_image_path}")
141
 
142
  if dataloader.is_stream:
143
- img = np.array(img)
144
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
145
- cv2.imshow("Result", img)
146
  if cv2.waitKey(1) & 0xFF == ord("q"):
147
  break
 
 
 
 
 
 
 
 
148
  except (KeyboardInterrupt, Exception) as e:
149
  dataloader.stop_event.set()
150
  dataloader.stop()
 
108
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
109
 
110
  self.nms = cfg.task.nms
111
+ self.save_path = save_path
112
+ self.save_predict = getattr(cfg.task, "save_predict", None)
113
  self.idx2label = cfg.class_list
114
 
115
  def solve(self, dataloader: StreamDataLoader):
 
125
  images = images.to(self.device)
126
  with torch.no_grad():
127
  predicts = self.model(images)
128
+ predicts = self.vec2box(predicts["Main"])
129
  nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
130
+ img = draw_bboxes(images[0], nms_out[0], idx2label=self.idx2label)
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  if dataloader.is_stream:
133
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
134
+ cv2.imshow("Prediction", img)
 
135
  if cv2.waitKey(1) & 0xFF == ord("q"):
136
  break
137
+ if not self.save_predict:
138
+ continue
139
+
140
+ if self.save_predict == False:
141
+ save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
142
+ img.save(save_image_path)
143
+ logger.info(f"💾 Saved visualize image at {save_image_path}")
144
+
145
  except (KeyboardInterrupt, Exception) as e:
146
  dataloader.stop_event.set()
147
  dataloader.stop()