henry000 commited on
Commit
8ca39dc
·
1 Parent(s): 936317c

✨ [Add] A stream dataloader for webcam

Browse files
requirements.txt CHANGED
@@ -3,6 +3,7 @@ graphviz
3
  hydra-core
4
  loguru
5
  numpy
 
6
  Pillow
7
  pytest
8
  pyyaml
 
3
  hydra-core
4
  loguru
5
  numpy
6
+ opencv-python
7
  Pillow
8
  pytest
9
  pyyaml
yolo/config/config.py CHANGED
@@ -113,6 +113,7 @@ class NMSConfig:
113
  @dataclass
114
  class InferenceConfig:
115
  task: str
 
116
  nms: NMSConfig
117
 
118
 
 
113
@dataclass
class InferenceConfig:
    # Name of the task being run (e.g. "inference").
    task: str
    # Input source: a filesystem path / URL string, or an int webcam index.
    source: Union[str, int]
    # Non-maximum-suppression settings applied to raw predictions.
    nms: NMSConfig
118
 
119
 
yolo/config/task/dataset/demo.yaml DELETED
@@ -1,3 +0,0 @@
1
- path: demo
2
-
3
- auto_download:
 
 
 
 
yolo/config/task/inference.yaml CHANGED
@@ -1,11 +1,10 @@
1
  task: inference
2
- defaults:
3
- - dataset: demo
4
  data:
5
  batch_size: 16
6
  shuffle: False
7
  pin_memory: True
8
  data_augment: {}
9
  nms:
10
- min_confidence: 0.75
11
  min_iou: 0.5
 
1
  task: inference
2
+ source: demo/images/inference/image.png
 
3
  data:
4
  batch_size: 16
5
  shuffle: False
6
  pin_memory: True
7
  data_augment: {}
8
  nms:
9
+ min_confidence: 0.1
10
  min_iou: 0.5
yolo/tools/data_loader.py CHANGED
@@ -1,16 +1,19 @@
1
  import os
2
  from os import path
3
- from typing import List, Tuple, Union
 
 
4
 
 
5
  import hydra
6
  import numpy as np
7
  import torch
8
  from loguru import logger
9
  from PIL import Image
10
  from rich.progress import track
 
11
  from torch.utils.data import DataLoader, Dataset
12
  from torchvision.transforms import functional as TF
13
- from tqdm.rich import tqdm
14
 
15
  from yolo.config.config import Config, TrainConfig
16
  from yolo.tools.data_augmentation import (
@@ -199,12 +202,107 @@ class YoloDataLoader(DataLoader):
199
 
200
 
201
  def create_dataloader(config: Config):
 
 
 
202
  if config.task.dataset.auto_download:
203
  prepare_dataset(config.task.dataset)
204
 
205
  return YoloDataLoader(config)
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  @hydra.main(config_path="../config", config_name="config", version_base=None)
209
  def main(cfg):
210
  dataloader = create_dataloader(cfg)
 
1
  import os
2
  from os import path
3
+ from queue import Empty, Queue
4
+ from threading import Event, Thread
5
+ from typing import Generator, List, Optional, Tuple, Union
6
 
7
+ import cv2
8
  import hydra
9
  import numpy as np
10
  import torch
11
  from loguru import logger
12
  from PIL import Image
13
  from rich.progress import track
14
+ from torch import Tensor
15
  from torch.utils.data import DataLoader, Dataset
16
  from torchvision.transforms import functional as TF
 
17
 
18
  from yolo.config.config import Config, TrainConfig
19
  from yolo.tools.data_augmentation import (
 
202
 
203
 
204
def create_dataloader(config: Config):
    """Return the dataloader matching the configured task.

    Inference gets a streaming loader fed from ``config.task.source``;
    every other task gets the standard dataset-backed loader, optionally
    downloading the dataset first.
    """
    if config.task.task == "inference":
        return StreamDataLoader(config)

    dataset_cfg = config.task.dataset
    if dataset_cfg.auto_download:
        prepare_dataset(dataset_cfg)
    return YoloDataLoader(config)
212
 
213
 
214
class StreamDataLoader:
    """Iterate model-ready frames from an image, folder, video file, or live stream.

    Live sources (int webcam index, rtmp:// URL) are read synchronously from
    ``cv2.VideoCapture`` on each ``__next__``; file sources are decoded on a
    background thread and handed over through a queue.
    """

    def __init__(self, config: Config):
        self.source = config.task.source
        self.running = True
        # Live stream: webcam index or RTMP URL. Everything else is file-based.
        self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")

        self.transform = AugmentationComposer([], config.image_size[0])
        self.stop_event = Event()

        if self.is_stream:
            self.cap = cv2.VideoCapture(self.source)
        else:
            self.queue = Queue()
            self.thread = Thread(target=self.load_source)
            self.thread.start()

    def load_source(self):
        """Dispatch the file-based source to the appropriate loader (thread entry)."""
        if os.path.isdir(self.source):  # image folder
            self.load_image_folder(self.source)
        elif any(self.source.lower().endswith(ext) for ext in [".mp4", ".avi", ".mkv"]):  # video file
            self.load_video_file(self.source)
        else:  # single image
            self.process_image(self.source)

    def load_image_folder(self, folder):
        """Walk *folder* recursively and enqueue every supported image."""
        for root, _, files in os.walk(folder):
            # Check the event in the outer loop too: a bare `break` would only
            # leave the inner files loop and the walk would keep going.
            if self.stop_event.is_set():
                return
            for file in files:
                if self.stop_event.is_set():
                    return
                if any(file.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".bmp"]):
                    self.process_image(os.path.join(root, file))

    def process_image(self, image_path):
        """Load a single image file and push it through the frame pipeline.

        Raises:
            ValueError: if the file cannot be opened/decoded as an image.
        """
        # PIL raises on failure rather than returning None, so wrap the open
        # instead of checking the result.
        try:
            image = Image.open(image_path).convert("RGB")
        except (OSError, ValueError) as err:
            raise ValueError(f"Error loading image: {image_path}") from err
        self.process_frame(image)

    def load_video_file(self, video_path):
        """Decode *video_path* frame by frame until EOF or stop()."""
        cap = cv2.VideoCapture(video_path)
        while self.running and not self.stop_event.is_set():
            ret, frame = cap.read()
            if not ret:
                break
            self.process_frame(frame)
        cap.release()

    def cv2_to_tensor(self, frame: np.ndarray) -> Tensor:
        """Convert a BGR uint8 OpenCV frame to a 1xCxHxW float tensor in [0, 1]."""
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_float = frame_rgb.astype("float32") / 255.0
        return torch.from_numpy(frame_float).permute(2, 0, 1)[None]

    def process_frame(self, frame):
        """Normalize a frame (OpenCV BGR array or PIL image) and publish it."""
        if isinstance(frame, np.ndarray):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
        # Empty (0, 5) box tensor: the composer expects boxes but inference has none.
        frame, _ = self.transform(frame, torch.zeros(0, 5))
        frame = TF.to_tensor(frame)[None]
        if not self.is_stream:
            self.queue.put(frame)
        else:
            self.current_frame = frame

    def __iter__(self) -> Generator[Tensor, None, None]:
        return self

    def __next__(self) -> Tensor:
        if self.is_stream:
            ret, frame = self.cap.read()
            if not ret:
                self.stop()
                raise StopIteration
            self.process_frame(frame)
            return self.current_frame
        else:
            try:
                # Timeout bounds the wait so iteration ends once the producer
                # thread has finished and the queue drains.
                frame = self.queue.get(timeout=1)
                return frame
            except Empty:
                raise StopIteration from None

    def stop(self):
        """Stop the producer (thread or capture) and release resources."""
        self.running = False
        # Signal the background loaders as well; previously only `running`
        # was cleared, so a folder walk could keep enqueueing after stop().
        self.stop_event.set()
        if self.is_stream:
            self.cap.release()
        else:
            self.thread.join(timeout=1)

    def __len__(self):
        # Streams are unbounded; report the pending queue depth otherwise.
        return self.queue.qsize() if not self.is_stream else 0
304
+
305
+
306
  @hydra.main(config_path="../config", config_name="config", version_base=None)
307
  def main(cfg):
308
  dataloader = create_dataloader(cfg)
yolo/tools/drawer.py CHANGED
@@ -14,6 +14,7 @@ def draw_bboxes(
14
  *,
15
  scaled_bbox: bool = True,
16
  save_path: str = "",
 
17
  ):
18
  """
19
  Draw bounding boxes on an image.
@@ -46,7 +47,7 @@ def draw_bboxes(
46
  draw.rectangle(shape, outline="red", width=3)
47
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
48
 
49
- save_image_path = os.path.join(save_path, "visualize.png")
50
  img.save(save_image_path) # Save the image with annotations
51
  logger.info(f"💾 Saved visualize image at {save_image_path}")
52
  return img
 
14
  *,
15
  scaled_bbox: bool = True,
16
  save_path: str = "",
17
+ save_name: str = "visualize.png",
18
  ):
19
  """
20
  Draw bounding boxes on an image.
 
47
  draw.rectangle(shape, outline="red", width=3)
48
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
49
 
50
+ save_image_path = os.path.join(save_path, save_name)
51
  img.save(save_image_path) # Save the image with annotations
52
  logger.info(f"💾 Saved visualize image at {save_image_path}")
53
  return img
yolo/tools/solver.py CHANGED
@@ -7,6 +7,7 @@ from torch.cuda.amp import GradScaler, autocast
7
 
8
  from yolo.config.config import Config, TrainConfig
9
  from yolo.model.yolo import YOLO
 
10
  from yolo.tools.drawer import draw_bboxes
11
  from yolo.tools.loss_functions import get_loss_function
12
  from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
@@ -103,15 +104,26 @@ class ModelTester:
103
  self.nms = cfg.task.nms
104
  self.save_path = save_path
105
 
106
- def solve(self, dataloader):
107
  logger.info("👀 Start Inference!")
108
 
109
- for images, _ in dataloader:
110
- images = images.to(self.device)
111
- with torch.no_grad():
112
- raw_output = self.model(images)
113
- predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
114
-
115
- nms_out = bbox_nms(predict, self.nms)
116
- for image, bbox in zip(images, nms_out):
117
- draw_bboxes(image, bbox, scaled_bbox=False, save_path=self.save_path)
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  from yolo.config.config import Config, TrainConfig
9
  from yolo.model.yolo import YOLO
10
+ from yolo.tools.data_loader import StreamDataLoader
11
  from yolo.tools.drawer import draw_bboxes
12
  from yolo.tools.loss_functions import get_loss_function
13
  from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
 
104
  self.nms = cfg.task.nms
105
  self.save_path = save_path
106
 
107
+ def solve(self, dataloader: StreamDataLoader):
108
  logger.info("👀 Start Inference!")
109
 
110
+ try:
111
+ for idx, images in enumerate(dataloader):
112
+ images = images.to(self.device)
113
+ with torch.no_grad():
114
+ raw_output = self.model(images)
115
+ predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
116
+ nms_out = bbox_nms(predict, self.nms)
117
+ draw_bboxes(
118
+ images[0], nms_out[0], scaled_bbox=False, save_path=self.save_path, save_name=f"frame{idx:03d}.png"
119
+ )
120
+ except KeyboardInterrupt:
121
+ logger.error("Interrupted by user")
122
+ dataloader.stop_event.set()
123
+ dataloader.stop()
124
+ except Exception as e:
125
+ logger.error(e)
126
+ dataloader.stop_event.set()
127
+ dataloader.stop()
128
+ raise e
129
+ dataloader.stop()
yolo/utils/bounding_box_utils.py CHANGED
@@ -303,7 +303,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
303
  batch_idx, *_ = torch.where(valid_mask)
304
  nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
305
  predicts_nms = []
306
- for idx in range(batch_idx.max() + 1):
307
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
308
 
309
  predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
 
303
  batch_idx, *_ = torch.where(valid_mask)
304
  nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
305
  predicts_nms = []
306
+ for idx in range(predicts.size(0)):
307
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
308
 
309
  predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)