✨ [Add] inference/predict or test mode
Browse files- yolo/config/task/inference.yaml +1 -1
- yolo/lazy.py +4 -1
- yolo/tools/solver.py +49 -0
yolo/config/task/inference.yaml
CHANGED
@@ -8,4 +8,4 @@ data:
|
|
8 |
nms:
|
9 |
min_confidence: 0.5
|
10 |
min_iou: 0.5
|
11 |
-
|
|
|
8 |
nms:
|
9 |
min_confidence: 0.5
|
10 |
min_iou: 0.5
|
11 |
+
save_predict: True
|
yolo/lazy.py
CHANGED
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
|
|
8 |
sys.path.append(str(project_root))
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
-
from yolo.tools.solver import TrainModel, ValidateModel
|
12 |
from yolo.utils.logging_utils import setup
|
13 |
|
14 |
|
@@ -34,6 +34,9 @@ def main(cfg: Config):
|
|
34 |
case "validation":
|
35 |
model = ValidateModel(cfg)
|
36 |
trainer.validate(model)
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
if __name__ == "__main__":
|
|
|
8 |
sys.path.append(str(project_root))
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
+
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
|
12 |
from yolo.utils.logging_utils import setup
|
13 |
|
14 |
|
|
|
34 |
case "validation":
|
35 |
model = ValidateModel(cfg)
|
36 |
trainer.validate(model)
|
37 |
+
case "inference":
|
38 |
+
model = InferenceModel(cfg)
|
39 |
+
trainer.predict(model)
|
40 |
|
41 |
|
42 |
if __name__ == "__main__":
|
yolo/tools/solver.py
CHANGED
@@ -1,9 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from lightning import LightningModule
|
2 |
from torchmetrics.detection import MeanAveragePrecision
|
3 |
|
4 |
from yolo.config.config import Config
|
5 |
from yolo.model.yolo import create_model
|
6 |
from yolo.tools.data_loader import create_dataloader
|
|
|
7 |
from yolo.tools.loss_functions import create_loss_function
|
8 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
9 |
from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
|
@@ -103,3 +109,46 @@ class TrainModel(ValidateModel):
|
|
103 |
optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
|
104 |
scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
|
105 |
return [optimizer], [scheduler]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
from lightning import LightningModule
|
7 |
from torchmetrics.detection import MeanAveragePrecision
|
8 |
|
9 |
from yolo.config.config import Config
|
10 |
from yolo.model.yolo import create_model
|
11 |
from yolo.tools.data_loader import create_dataloader
|
12 |
+
from yolo.tools.drawer import draw_bboxes
|
13 |
from yolo.tools.loss_functions import create_loss_function
|
14 |
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
|
15 |
from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
|
|
|
109 |
optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
|
110 |
scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
|
111 |
return [optimizer], [scheduler]
|
112 |
+
|
113 |
+
|
114 |
+
class InferenceModel(BaseModel):
|
115 |
+
def __init__(self, cfg: Config):
|
116 |
+
super().__init__(cfg)
|
117 |
+
self.cfg = cfg
|
118 |
+
# TODO: Add FastModel
|
119 |
+
self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
120 |
+
|
121 |
+
def setup(self, stage):
|
122 |
+
self.vec2box = create_converter(
|
123 |
+
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
|
124 |
+
)
|
125 |
+
self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)
|
126 |
+
|
127 |
+
def predict_dataloader(self):
|
128 |
+
return self.predict_loader
|
129 |
+
|
130 |
+
def predict_step(self, batch, batch_idx):
|
131 |
+
images, rev_tensor, origin_frame = batch
|
132 |
+
predicts = self.post_process(self(images), rev_tensor)
|
133 |
+
img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
|
134 |
+
if getattr(self.predict_loader, "is_stream", None):
|
135 |
+
fps = self._display_stream(img)
|
136 |
+
else:
|
137 |
+
fps = None
|
138 |
+
if getattr(self.cfg.task, "save_predict", None):
|
139 |
+
self._save_image(img, batch_idx)
|
140 |
+
return img, fps
|
141 |
+
|
142 |
+
def _display_stream(self, img):
|
143 |
+
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
144 |
+
fps = 1 / (time.time() - self.trainer.current_epoch_start_time)
|
145 |
+
cv2.putText(img, f"FPS: {fps:.2f}", (0, 15), 0, 0.5, (100, 255, 0), 1, cv2.LINE_AA)
|
146 |
+
cv2.imshow("Prediction", img)
|
147 |
+
if cv2.waitKey(1) & 0xFF == ord("q"):
|
148 |
+
self.trainer.should_stop = True
|
149 |
+
return fps
|
150 |
+
|
151 |
+
def _save_image(self, img, batch_idx):
|
152 |
+
save_image_path = Path(self.logger.save_dir) / f"frame{batch_idx:03d}.png"
|
153 |
+
img.save(save_image_path)
|
154 |
+
print(f"💾 Saved visualize image at {save_image_path}")
|