:bug: [Fix] a bug with drawing picture after merge
Browse files- yolo/config/task/inference.yaml +1 -1
- yolo/tools/drawer.py +2 -4
- yolo/tools/solver.py +14 -17
yolo/config/task/inference.yaml
CHANGED
@@ -8,4 +8,4 @@ data:
|
|
8 |
nms:
|
9 |
min_confidence: 0.5
|
10 |
min_iou: 0.5
|
11 |
-
save_predict:
|
|
|
8 |
nms:
|
9 |
min_confidence: 0.5
|
10 |
min_iou: 0.5
|
11 |
+
# save_predict: True
|
yolo/tools/drawer.py
CHANGED
@@ -13,8 +13,6 @@ def draw_bboxes(
|
|
13 |
img: Union[Image.Image, torch.Tensor],
|
14 |
bboxes: List[List[Union[int, float]]],
|
15 |
*,
|
16 |
-
save_path: str = "",
|
17 |
-
save_name: str = "visualize.png",
|
18 |
idx2label: Optional[list],
|
19 |
):
|
20 |
"""
|
@@ -114,6 +112,6 @@ def draw_model(*, model_cfg=None, model=None, v7_base=False):
|
|
114 |
dot.edge(str(idx), str(jdx))
|
115 |
try:
|
116 |
dot.render("Model-arch", format="png", cleanup=True)
|
|
|
117 |
except:
|
118 |
-
logger.
|
119 |
-
logger.info("🎨 Drawing Model Architecture at Model-arch.png")
|
|
|
13 |
img: Union[Image.Image, torch.Tensor],
|
14 |
bboxes: List[List[Union[int, float]]],
|
15 |
*,
|
|
|
|
|
16 |
idx2label: Optional[list],
|
17 |
):
|
18 |
"""
|
|
|
112 |
dot.edge(str(idx), str(jdx))
|
113 |
try:
|
114 |
dot.render("Model-arch", format="png", cleanup=True)
|
115 |
+
logger.info("🎨 Drawing Model Architecture at Model-arch.png")
|
116 |
except:
|
117 |
+
logger.warning("⚠️ Could not find graphviz backend, continue without drawing the model architecture")
|
|
yolo/tools/solver.py
CHANGED
@@ -108,7 +108,8 @@ class ModelTester:
|
|
108 |
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
109 |
|
110 |
self.nms = cfg.task.nms
|
111 |
-
self.save_path = save_path
|
|
|
112 |
self.idx2label = cfg.class_list
|
113 |
|
114 |
def solve(self, dataloader: StreamDataLoader):
|
@@ -124,27 +125,23 @@ class ModelTester:
|
|
124 |
images = images.to(self.device)
|
125 |
with torch.no_grad():
|
126 |
predicts = self.model(images)
|
127 |
-
|
128 |
nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
|
129 |
-
draw_bboxes(
|
130 |
-
images[0],
|
131 |
-
nms_out[0],
|
132 |
-
save_path=self.save_path,
|
133 |
-
save_name=f"frame{idx:03d}.png",
|
134 |
-
idx2label=self.idx2label,
|
135 |
-
)
|
136 |
-
logger.info(f"img size: {img.shape}")
|
137 |
-
if self.save_path is not None:
|
138 |
-
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
139 |
-
img.save(save_image_path)
|
140 |
-
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
141 |
|
142 |
if dataloader.is_stream:
|
143 |
-
img = np.array(img)
|
144 |
-
|
145 |
-
cv2.imshow("Result", img)
|
146 |
if cv2.waitKey(1) & 0xFF == ord("q"):
|
147 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
except (KeyboardInterrupt, Exception) as e:
|
149 |
dataloader.stop_event.set()
|
150 |
dataloader.stop()
|
|
|
108 |
self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
|
109 |
|
110 |
self.nms = cfg.task.nms
|
111 |
+
self.save_path = save_path
|
112 |
+
self.save_predict = getattr(cfg.task, "save_predict", None)
|
113 |
self.idx2label = cfg.class_list
|
114 |
|
115 |
def solve(self, dataloader: StreamDataLoader):
|
|
|
125 |
images = images.to(self.device)
|
126 |
with torch.no_grad():
|
127 |
predicts = self.model(images)
|
128 |
+
predicts = self.vec2box(predicts["Main"])
|
129 |
nms_out = bbox_nms(predicts[0], predicts[2], self.nms)
|
130 |
+
img = draw_bboxes(images[0], nms_out[0], idx2label=self.idx2label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
if dataloader.is_stream:
|
133 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
134 |
+
cv2.imshow("Prediction", img)
|
|
|
135 |
if cv2.waitKey(1) & 0xFF == ord("q"):
|
136 |
break
|
137 |
+
if not self.save_predict:
|
138 |
+
continue
|
139 |
+
|
140 |
+
if self.save_predict == False:
|
141 |
+
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
142 |
+
img.save(save_image_path)
|
143 |
+
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
144 |
+
|
145 |
except (KeyboardInterrupt, Exception) as e:
|
146 |
dataloader.stop_event.set()
|
147 |
dataloader.stop()
|