henry000 commited on
Commit
7f8235a
·
1 Parent(s): 4b8ec68

✨ [Add] inference/predict or test mode

Browse files
yolo/config/task/inference.yaml CHANGED
@@ -8,4 +8,4 @@ data:
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
- # save_predict: True
 
8
  nms:
9
  min_confidence: 0.5
10
  min_iou: 0.5
11
+ save_predict: True
yolo/lazy.py CHANGED
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
- from yolo.tools.solver import TrainModel, ValidateModel
12
  from yolo.utils.logging_utils import setup
13
 
14
 
@@ -34,6 +34,9 @@ def main(cfg: Config):
34
  case "validation":
35
  model = ValidateModel(cfg)
36
  trainer.validate(model)
 
 
 
37
 
38
 
39
  if __name__ == "__main__":
 
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
+ from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
12
  from yolo.utils.logging_utils import setup
13
 
14
 
 
34
  case "validation":
35
  model = ValidateModel(cfg)
36
  trainer.validate(model)
37
+ case "inference":
38
+ model = InferenceModel(cfg)
39
+ trainer.predict(model)
40
 
41
 
42
  if __name__ == "__main__":
yolo/tools/solver.py CHANGED
@@ -1,9 +1,15 @@
 
 
 
 
 
1
  from lightning import LightningModule
2
  from torchmetrics.detection import MeanAveragePrecision
3
 
4
  from yolo.config.config import Config
5
  from yolo.model.yolo import create_model
6
  from yolo.tools.data_loader import create_dataloader
 
7
  from yolo.tools.loss_functions import create_loss_function
8
  from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
9
  from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
@@ -103,3 +109,46 @@ class TrainModel(ValidateModel):
103
  optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
104
  scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
105
  return [optimizer], [scheduler]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import numpy as np
6
  from lightning import LightningModule
7
  from torchmetrics.detection import MeanAveragePrecision
8
 
9
  from yolo.config.config import Config
10
  from yolo.model.yolo import create_model
11
  from yolo.tools.data_loader import create_dataloader
12
+ from yolo.tools.drawer import draw_bboxes
13
  from yolo.tools.loss_functions import create_loss_function
14
  from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
15
  from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
 
109
  optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
110
  scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
111
  return [optimizer], [scheduler]
112
+
113
+
114
+ class InferenceModel(BaseModel):
115
+ def __init__(self, cfg: Config):
116
+ super().__init__(cfg)
117
+ self.cfg = cfg
118
+ # TODO: Add FastModel
119
+ self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
120
+
121
+ def setup(self, stage):
122
+ self.vec2box = create_converter(
123
+ self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
124
+ )
125
+ self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)
126
+
127
+ def predict_dataloader(self):
128
+ return self.predict_loader
129
+
130
+ def predict_step(self, batch, batch_idx):
131
+ images, rev_tensor, origin_frame = batch
132
+ predicts = self.post_process(self(images), rev_tensor)
133
+ img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
134
+ if getattr(self.predict_loader, "is_stream", None):
135
+ fps = self._display_stream(img)
136
+ else:
137
+ fps = None
138
+ if getattr(self.cfg.task, "save_predict", None):
139
+ self._save_image(img, batch_idx)
140
+ return img, fps
141
+
142
+ def _display_stream(self, img):
143
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
144
+ fps = 1 / (time.time() - self.trainer.current_epoch_start_time)
145
+ cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
146
+ cv2.imshow("Prediction", img)
147
+ if cv2.waitKey(1) & 0xFF == ord("q"):
148
+ self.trainer.should_stop = True
149
+ return fps
150
+
151
+ def _save_image(self, img, batch_idx):
152
+ save_image_path = Path(self.logger.save_dir) / f"frame{batch_idx:03d}.png"
153
+ img.save(save_image_path)
154
+ print(f"💾 Saved visualize image at {save_image_path}")