henry000 commited on
Commit
6228522
Β·
2 Parent(s): 9000502 ecf6aba

πŸ”€ [Merge] branch 'INFERENCE' of github.com:WongKinYiu/yolov9mit into INFERENCE

Browse files
yolo/config/config.py CHANGED
@@ -107,6 +107,7 @@ class InferenceConfig:
107
  nms: NMSConfig
108
  data: DataConfig
109
  fast_inference: Optional[None]
 
110
 
111
 
112
  @dataclass
 
107
  nms: NMSConfig
108
  data: DataConfig
109
  fast_inference: Optional[None]
110
+ save_predict: bool
111
 
112
 
113
  @dataclass
yolo/config/task/inference.yaml CHANGED
@@ -8,3 +8,4 @@ data:
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
 
 
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
+ save_predict: true
yolo/tools/drawer.py CHANGED
@@ -60,10 +60,6 @@ def draw_bboxes(
60
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
61
  draw.text((x_min, y_min), label_text, fill="white", font=font)
62
 
63
- os.makedirs(save_path, exist_ok=True)
64
- save_image_path = os.path.join(save_path, save_name)
65
- img.save(save_image_path) # Save the image with annotations
66
- logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
67
  return img
68
 
69
 
 
60
  draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
61
  draw.text((x_min, y_min), label_text, fill="white", font=font)
62
 
 
 
 
 
63
  return img
64
 
65
 
yolo/tools/solver.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  from loguru import logger
3
  from torch import Tensor
@@ -106,13 +108,17 @@ class ModelTester:
106
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
107
 
108
  self.nms = cfg.task.nms
 
109
  self.idx2label = cfg.class_list
110
- self.save_path = save_path
111
 
112
  def solve(self, dataloader: StreamDataLoader):
113
  logger.info("πŸ‘€ Start Inference!")
114
  if isinstance(self.model, torch.nn.Module):
115
  self.model.eval()
 
 
 
 
116
  try:
117
  for idx, images in enumerate(dataloader):
118
  images = images.to(self.device)
@@ -127,6 +133,18 @@ class ModelTester:
127
  save_name=f"frame{idx:03d}.png",
128
  idx2label=self.idx2label,
129
  )
 
 
 
 
 
 
 
 
 
 
 
 
130
  except (KeyboardInterrupt, Exception) as e:
131
  dataloader.stop_event.set()
132
  dataloader.stop()
 
1
+ import os
2
+
3
  import torch
4
  from loguru import logger
5
  from torch import Tensor
 
108
  self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
109
 
110
  self.nms = cfg.task.nms
111
+ self.save_path = save_path if getattr(cfg.task, "save_predict", True) else None
112
  self.idx2label = cfg.class_list
 
113
 
114
  def solve(self, dataloader: StreamDataLoader):
115
  logger.info("πŸ‘€ Start Inference!")
116
  if isinstance(self.model, torch.nn.Module):
117
  self.model.eval()
118
+
119
+ if dataloader.is_stream:
120
+ import cv2
121
+ import numpy as np
122
  try:
123
  for idx, images in enumerate(dataloader):
124
  images = images.to(self.device)
 
133
  save_name=f"frame{idx:03d}.png",
134
  idx2label=self.idx2label,
135
  )
136
+ logger.info(f"img size: {img.shape}")
137
+ if self.save_path is not None:
138
+ save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
139
+ img.save(save_image_path)
140
+ logger.info(f"πŸ’Ύ Saved visualize image at {save_image_path}")
141
+
142
+ if dataloader.is_stream:
143
+ img = np.array(img)
144
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
145
+ cv2.imshow("Result", img)
146
+ if cv2.waitKey(1) & 0xFF == ord("q"):
147
+ break
148
  except (KeyboardInterrupt, Exception) as e:
149
  dataloader.stop_event.set()
150
  dataloader.stop()