✨ [Add] A stream dataloader for webcam
Browse files- requirements.txt +1 -0
- yolo/config/config.py +1 -0
- yolo/config/task/dataset/demo.yaml +0 -3
- yolo/config/task/inference.yaml +2 -3
- yolo/tools/data_loader.py +100 -2
- yolo/tools/drawer.py +2 -1
- yolo/tools/solver.py +22 -10
- yolo/utils/bounding_box_utils.py +1 -1
requirements.txt
CHANGED
@@ -3,6 +3,7 @@ graphviz
|
|
3 |
hydra-core
|
4 |
loguru
|
5 |
numpy
|
|
|
6 |
Pillow
|
7 |
pytest
|
8 |
pyyaml
|
|
|
3 |
hydra-core
|
4 |
loguru
|
5 |
numpy
|
6 |
+
opencv-python
|
7 |
Pillow
|
8 |
pytest
|
9 |
pyyaml
|
yolo/config/config.py
CHANGED
@@ -113,6 +113,7 @@ class NMSConfig:
|
|
113 |
@dataclass
|
114 |
class InferenceConfig:
|
115 |
task: str
|
|
|
116 |
nms: NMSConfig
|
117 |
|
118 |
|
|
|
113 |
@dataclass
|
114 |
class InferenceConfig:
|
115 |
task: str
|
116 |
+
source: Union[str, int]
|
117 |
nms: NMSConfig
|
118 |
|
119 |
|
yolo/config/task/dataset/demo.yaml
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
path: demo
|
2 |
-
|
3 |
-
auto_download:
|
|
|
|
|
|
|
|
yolo/config/task/inference.yaml
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
task: inference
|
2 |
-
|
3 |
-
- dataset: demo
|
4 |
data:
|
5 |
batch_size: 16
|
6 |
shuffle: False
|
7 |
pin_memory: True
|
8 |
data_augment: {}
|
9 |
nms:
|
10 |
-
min_confidence: 0.
|
11 |
min_iou: 0.5
|
|
|
1 |
task: inference
|
2 |
+
source: demo/images/inference/image.png
|
|
|
3 |
data:
|
4 |
batch_size: 16
|
5 |
shuffle: False
|
6 |
pin_memory: True
|
7 |
data_augment: {}
|
8 |
nms:
|
9 |
+
min_confidence: 0.1
|
10 |
min_iou: 0.5
|
yolo/tools/data_loader.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1 |
import os
|
2 |
from os import path
|
3 |
-
from
|
|
|
|
|
4 |
|
|
|
5 |
import hydra
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
10 |
from rich.progress import track
|
|
|
11 |
from torch.utils.data import DataLoader, Dataset
|
12 |
from torchvision.transforms import functional as TF
|
13 |
-
from tqdm.rich import tqdm
|
14 |
|
15 |
from yolo.config.config import Config, TrainConfig
|
16 |
from yolo.tools.data_augmentation import (
|
@@ -199,12 +202,107 @@ class YoloDataLoader(DataLoader):
|
|
199 |
|
200 |
|
201 |
def create_dataloader(config: Config):
|
|
|
|
|
|
|
202 |
if config.task.dataset.auto_download:
|
203 |
prepare_dataset(config.task.dataset)
|
204 |
|
205 |
return YoloDataLoader(config)
|
206 |
|
207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
209 |
def main(cfg):
|
210 |
dataloader = create_dataloader(cfg)
|
|
|
1 |
import os
|
2 |
from os import path
|
3 |
+
from queue import Empty, Queue
|
4 |
+
from threading import Event, Thread
|
5 |
+
from typing import Generator, List, Optional, Tuple, Union
|
6 |
|
7 |
+
import cv2
|
8 |
import hydra
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
from loguru import logger
|
12 |
from PIL import Image
|
13 |
from rich.progress import track
|
14 |
+
from torch import Tensor
|
15 |
from torch.utils.data import DataLoader, Dataset
|
16 |
from torchvision.transforms import functional as TF
|
|
|
17 |
|
18 |
from yolo.config.config import Config, TrainConfig
|
19 |
from yolo.tools.data_augmentation import (
|
|
|
202 |
|
203 |
|
204 |
def create_dataloader(config: Config):
|
205 |
+
if config.task.task == "inference":
|
206 |
+
return StreamDataLoader(config)
|
207 |
+
|
208 |
if config.task.dataset.auto_download:
|
209 |
prepare_dataset(config.task.dataset)
|
210 |
|
211 |
return YoloDataLoader(config)
|
212 |
|
213 |
|
214 |
+
class StreamDataLoader:
|
215 |
+
def __init__(self, config: Config):
|
216 |
+
self.source = config.task.source
|
217 |
+
self.running = True
|
218 |
+
self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")
|
219 |
+
|
220 |
+
self.transform = AugmentationComposer([], config.image_size[0])
|
221 |
+
self.stop_event = Event()
|
222 |
+
|
223 |
+
if self.is_stream:
|
224 |
+
self.cap = cv2.VideoCapture(self.source)
|
225 |
+
else:
|
226 |
+
self.queue = Queue()
|
227 |
+
self.thread = Thread(target=self.load_source)
|
228 |
+
self.thread.start()
|
229 |
+
|
230 |
+
def load_source(self):
|
231 |
+
if os.path.isdir(self.source): # image folder
|
232 |
+
self.load_image_folder(self.source)
|
233 |
+
elif any(self.source.lower().endswith(ext) for ext in [".mp4", ".avi", ".mkv"]): # Video file
|
234 |
+
self.load_video_file(self.source)
|
235 |
+
else: # Single image
|
236 |
+
self.process_image(self.source)
|
237 |
+
|
238 |
+
def load_image_folder(self, folder):
|
239 |
+
for root, _, files in os.walk(folder):
|
240 |
+
for file in files:
|
241 |
+
if self.stop_event.is_set():
|
242 |
+
break
|
243 |
+
if any(file.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".bmp"]):
|
244 |
+
self.process_image(os.path.join(root, file))
|
245 |
+
|
246 |
+
def process_image(self, image_path):
|
247 |
+
image = Image.open(image_path).convert("RGB")
|
248 |
+
if image is None:
|
249 |
+
raise ValueError(f"Error loading image: {image_path}")
|
250 |
+
self.process_frame(image)
|
251 |
+
|
252 |
+
def load_video_file(self, video_path):
|
253 |
+
cap = cv2.VideoCapture(video_path)
|
254 |
+
while self.running:
|
255 |
+
ret, frame = cap.read()
|
256 |
+
if not ret:
|
257 |
+
break
|
258 |
+
self.process_frame(frame)
|
259 |
+
cap.release()
|
260 |
+
|
261 |
+
def cv2_to_tensor(self, frame: np.ndarray) -> Tensor:
|
262 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
263 |
+
frame_float = frame_rgb.astype("float32") / 255.0
|
264 |
+
return torch.from_numpy(frame_float).permute(2, 0, 1)[None]
|
265 |
+
|
266 |
+
def process_frame(self, frame):
|
267 |
+
if isinstance(frame, np.ndarray):
|
268 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
269 |
+
frame = Image.fromarray(frame)
|
270 |
+
frame, _ = self.transform(frame, torch.zeros(0, 5))
|
271 |
+
frame = TF.to_tensor(frame)[None]
|
272 |
+
if not self.is_stream:
|
273 |
+
self.queue.put(frame)
|
274 |
+
else:
|
275 |
+
self.current_frame = frame
|
276 |
+
|
277 |
+
def __iter__(self) -> Generator[Tensor, None, None]:
|
278 |
+
return self
|
279 |
+
|
280 |
+
def __next__(self) -> Tensor:
|
281 |
+
if self.is_stream:
|
282 |
+
ret, frame = self.cap.read()
|
283 |
+
if not ret:
|
284 |
+
self.stop()
|
285 |
+
raise StopIteration
|
286 |
+
self.process_frame(frame)
|
287 |
+
return self.current_frame
|
288 |
+
else:
|
289 |
+
try:
|
290 |
+
frame = self.queue.get(timeout=1)
|
291 |
+
return frame
|
292 |
+
except Empty:
|
293 |
+
raise StopIteration
|
294 |
+
|
295 |
+
def stop(self):
|
296 |
+
self.running = False
|
297 |
+
if self.is_stream:
|
298 |
+
self.cap.release()
|
299 |
+
else:
|
300 |
+
self.thread.join(timeout=1)
|
301 |
+
|
302 |
+
def __len__(self):
|
303 |
+
return self.queue.qsize() if not self.is_stream else 0
|
304 |
+
|
305 |
+
|
306 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
307 |
def main(cfg):
|
308 |
dataloader = create_dataloader(cfg)
|
yolo/tools/drawer.py
CHANGED
@@ -14,6 +14,7 @@ def draw_bboxes(
|
|
14 |
*,
|
15 |
scaled_bbox: bool = True,
|
16 |
save_path: str = "",
|
|
|
17 |
):
|
18 |
"""
|
19 |
Draw bounding boxes on an image.
|
@@ -46,7 +47,7 @@ def draw_bboxes(
|
|
46 |
draw.rectangle(shape, outline="red", width=3)
|
47 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
48 |
|
49 |
-
save_image_path = os.path.join(save_path,
|
50 |
img.save(save_image_path) # Save the image with annotations
|
51 |
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
52 |
return img
|
|
|
14 |
*,
|
15 |
scaled_bbox: bool = True,
|
16 |
save_path: str = "",
|
17 |
+
save_name: str = "visualize.png",
|
18 |
):
|
19 |
"""
|
20 |
Draw bounding boxes on an image.
|
|
|
47 |
draw.rectangle(shape, outline="red", width=3)
|
48 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
49 |
|
50 |
+
save_image_path = os.path.join(save_path, save_name)
|
51 |
img.save(save_image_path) # Save the image with annotations
|
52 |
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
53 |
return img
|
yolo/tools/solver.py
CHANGED
@@ -7,6 +7,7 @@ from torch.cuda.amp import GradScaler, autocast
|
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
from yolo.model.yolo import YOLO
|
|
|
10 |
from yolo.tools.drawer import draw_bboxes
|
11 |
from yolo.tools.loss_functions import get_loss_function
|
12 |
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
@@ -103,15 +104,26 @@ class ModelTester:
|
|
103 |
self.nms = cfg.task.nms
|
104 |
self.save_path = save_path
|
105 |
|
106 |
-
def solve(self, dataloader):
|
107 |
logger.info("👀 Start Inference!")
|
108 |
|
109 |
-
|
110 |
-
images
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
from yolo.config.config import Config, TrainConfig
|
9 |
from yolo.model.yolo import YOLO
|
10 |
+
from yolo.tools.data_loader import StreamDataLoader
|
11 |
from yolo.tools.drawer import draw_bboxes
|
12 |
from yolo.tools.loss_functions import get_loss_function
|
13 |
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
|
|
104 |
self.nms = cfg.task.nms
|
105 |
self.save_path = save_path
|
106 |
|
107 |
+
def solve(self, dataloader: StreamDataLoader):
|
108 |
logger.info("👀 Start Inference!")
|
109 |
|
110 |
+
try:
|
111 |
+
for idx, images in enumerate(dataloader):
|
112 |
+
images = images.to(self.device)
|
113 |
+
with torch.no_grad():
|
114 |
+
raw_output = self.model(images)
|
115 |
+
predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
|
116 |
+
nms_out = bbox_nms(predict, self.nms)
|
117 |
+
draw_bboxes(
|
118 |
+
images[0], nms_out[0], scaled_bbox=False, save_path=self.save_path, save_name=f"frame{idx:03d}.png"
|
119 |
+
)
|
120 |
+
except KeyboardInterrupt:
|
121 |
+
logger.error("Interrupted by user")
|
122 |
+
dataloader.stop_event.set()
|
123 |
+
dataloader.stop()
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(e)
|
126 |
+
dataloader.stop_event.set()
|
127 |
+
dataloader.stop()
|
128 |
+
raise e
|
129 |
+
dataloader.stop()
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -303,7 +303,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
|
|
303 |
batch_idx, *_ = torch.where(valid_mask)
|
304 |
nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
|
305 |
predicts_nms = []
|
306 |
-
for idx in range(
|
307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
308 |
|
309 |
predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
|
|
|
303 |
batch_idx, *_ = torch.where(valid_mask)
|
304 |
nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
|
305 |
predicts_nms = []
|
306 |
+
for idx in range(predicts.size(0)):
|
307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
308 |
|
309 |
predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
|