LaiEthanLai HenryTsui committed on
Commit
ecf6aba
·
unverified ·
1 Parent(s): 860b0a5

[🎨] Add support for displaying webcam videos with predicted bounding boxes (#27)

Browse files

* ✅ [Pass] tests, skip drawing if graphviz not found

* ✨ [Add] Display processed webcam videos

---------

Co-authored-by: HenryTsui <[email protected]>

yolo/config/config.py CHANGED
@@ -108,6 +108,7 @@ class InferenceConfig:
108
  nms: NMSConfig
109
  data: DataConfig
110
  fast_inference: Optional[None]
 
111
 
112
 
113
  @dataclass
 
108
  nms: NMSConfig
109
  data: DataConfig
110
  fast_inference: Optional[None]
111
+ save_predict: bool
112
 
113
 
114
  @dataclass
yolo/config/task/inference.yaml CHANGED
@@ -7,4 +7,5 @@ data:
7
  data_augment: {}
8
  nms:
9
  min_confidence: 0.5
10
- min_iou: 0.5
 
 
7
  data_augment: {}
8
  nms:
9
  min_confidence: 0.5
10
+ min_iou: 0.5
11
+ save_predict: true
yolo/tools/drawer.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import random
3
  from typing import List, Optional, Union
4
 
 
5
  import numpy as np
6
  import torch
7
  from loguru import logger
@@ -65,9 +66,6 @@ def draw_bboxes(
65
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
66
  draw.text((x_min, y_min), label_text, fill="white", font=font)
67
 
68
- save_image_path = os.path.join(save_path, save_name)
69
- img.save(save_image_path) # Save the image with annotations
70
- logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
71
  return img
72
 
73
 
 
2
  import random
3
  from typing import List, Optional, Union
4
 
5
+
6
  import numpy as np
7
  import torch
8
  from loguru import logger
 
66
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
67
  draw.text((x_min, y_min), label_text, fill="white", font=font)
68
 
 
 
 
69
  return img
70
 
71
 
yolo/tools/solver.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from loguru import logger
3
  from torch import Tensor
@@ -106,12 +108,15 @@ class ModelTester:
106
 
107
  self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
108
  self.nms = cfg.task.nms
 
109
  self.idx2label = cfg.class_list
110
- self.save_path = save_path
111
 
112
  def solve(self, dataloader: StreamDataLoader):
113
  logger.info("πŸ‘€ Start Inference!")
114
 
 
 
 
115
  try:
116
  for idx, images in enumerate(dataloader):
117
  images = images.to(self.device)
@@ -119,7 +124,7 @@ class ModelTester:
119
  raw_output = self.model(images)
120
  predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
121
  nms_out = bbox_nms(predict, self.nms)
122
- draw_bboxes(
123
  images[0],
124
  nms_out[0],
125
  scaled_bbox=False,
@@ -127,6 +132,18 @@ class ModelTester:
127
  save_name=f"frame{idx:03d}.png",
128
  idx2label=self.idx2label,
129
  )
 
 
 
 
 
 
 
 
 
 
 
 
130
  except (KeyboardInterrupt, Exception) as e:
131
  dataloader.stop_event.set()
132
  dataloader.stop()
 
1
+ import os
2
+
3
  import torch
4
  from loguru import logger
5
  from torch import Tensor
 
108
 
109
  self.anchor2box = AnchorBoxConverter(cfg.model, cfg.image_size, device)
110
  self.nms = cfg.task.nms
111
+ self.save_path = save_path if getattr(cfg.task, "save_predict", True) else None
112
  self.idx2label = cfg.class_list
 
113
 
114
  def solve(self, dataloader: StreamDataLoader):
115
  logger.info("πŸ‘€ Start Inference!")
116
 
117
+ if dataloader.is_stream:
118
+ import cv2
119
+ import numpy as np
120
  try:
121
  for idx, images in enumerate(dataloader):
122
  images = images.to(self.device)
 
124
  raw_output = self.model(images)
125
  predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
126
  nms_out = bbox_nms(predict, self.nms)
127
+ img = draw_bboxes(
128
  images[0],
129
  nms_out[0],
130
  scaled_bbox=False,
 
132
  save_name=f"frame{idx:03d}.png",
133
  idx2label=self.idx2label,
134
  )
135
+ logger.info(f"img size: {img.shape}")
136
+ if self.save_path is not None:
137
+ save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
138
+ img.save(save_image_path)
139
+ logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
140
+
141
+ if dataloader.is_stream:
142
+ img = np.array(img)
143
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
144
+ cv2.imshow("Result", img)
145
+ if cv2.waitKey(1) & 0xFF == ord("q"):
146
+ break
147
  except (KeyboardInterrupt, Exception) as e:
148
  dataloader.stop_event.set()
149
  dataloader.stop()