✨ [New] validation code! run my pycocotools
Browse files- yolo/lazy.py +5 -1
- yolo/tools/solver.py +13 -13
- yolo/utils/logging_utils.py +2 -2
yolo/lazy.py
CHANGED
@@ -9,7 +9,7 @@ sys.path.append(str(project_root))
|
|
9 |
from yolo.config.config import Config
|
10 |
from yolo.model.yolo import create_model
|
11 |
from yolo.tools.data_loader import create_dataloader
|
12 |
-
from yolo.tools.solver import ModelTester, ModelTrainer
|
13 |
from yolo.utils.bounding_box_utils import Vec2Box
|
14 |
from yolo.utils.deploy_utils import FastModelLoader
|
15 |
from yolo.utils.logging_utils import ProgressLogger
|
@@ -37,6 +37,10 @@ def main(cfg: Config):
|
|
37 |
tester = ModelTester(cfg, model, vec2box, progress, device)
|
38 |
tester.solve(dataloader)
|
39 |
|
|
|
|
|
|
|
|
|
40 |
|
41 |
if __name__ == "__main__":
|
42 |
main()
|
|
|
9 |
from yolo.config.config import Config
|
10 |
from yolo.model.yolo import create_model
|
11 |
from yolo.tools.data_loader import create_dataloader
|
12 |
+
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
13 |
from yolo.utils.bounding_box_utils import Vec2Box
|
14 |
from yolo.utils.deploy_utils import FastModelLoader
|
15 |
from yolo.utils.logging_utils import ProgressLogger
|
|
|
37 |
tester = ModelTester(cfg, model, vec2box, progress, device)
|
38 |
tester.solve(dataloader)
|
39 |
|
40 |
+
if cfg.task.task == "validation":
|
41 |
+
valider = ModelValidator(cfg.task, model, vec2box, progress, device)
|
42 |
+
valider.solve(dataloader)
|
43 |
+
|
44 |
|
45 |
if __name__ == "__main__":
|
46 |
main()
|
yolo/tools/solver.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import time
|
3 |
|
@@ -15,12 +16,14 @@ from yolo.model.yolo import YOLO
|
|
15 |
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
16 |
from yolo.tools.drawer import draw_bboxes, draw_model
|
17 |
from yolo.tools.loss_functions import create_loss_function
|
18 |
-
from yolo.utils.bounding_box_utils import Vec2Box
|
19 |
from yolo.utils.logging_utils import ProgressLogger, log_model_structure
|
20 |
from yolo.utils.model_utils import (
|
21 |
ExponentialMovingAverage,
|
|
|
22 |
create_optimizer,
|
23 |
create_scheduler,
|
|
|
24 |
)
|
25 |
|
26 |
|
@@ -176,32 +179,29 @@ class ModelValidator:
|
|
176 |
validation_cfg: ValidationConfig,
|
177 |
model: YOLO,
|
178 |
vec2box: Vec2Box,
|
179 |
-
device,
|
180 |
progress: ProgressLogger,
|
|
|
181 |
):
|
182 |
self.model = model
|
183 |
-
self.vec2box = vec2box
|
184 |
self.device = device
|
185 |
self.progress = progress
|
186 |
|
187 |
-
self.
|
|
|
188 |
|
189 |
def solve(self, dataloader):
|
190 |
# logger.info("🧪 Start Validation!")
|
191 |
self.model.eval()
|
192 |
-
|
193 |
-
iou_thresholds = torch.arange(0.5, 1.0, 0.05)
|
194 |
-
map_all = []
|
195 |
self.progress.start_one_epoch(len(dataloader))
|
196 |
for images, targets, rev_tensor, img_paths in dataloader:
|
197 |
images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
|
198 |
with torch.no_grad():
|
199 |
predicts = self.model(images)
|
200 |
-
|
201 |
-
|
202 |
-
for idx, predict in enumerate(nms_out):
|
203 |
-
map_value = calculate_map(predict, targets[idx], iou_thresholds)
|
204 |
-
map_all.append(map_value[0])
|
205 |
-
self.progress.one_batch(mapp=torch.Tensor(map_all).mean())
|
206 |
|
|
|
207 |
self.progress.finish_one_epoch()
|
|
|
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
import time
|
4 |
|
|
|
16 |
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
17 |
from yolo.tools.drawer import draw_bboxes, draw_model
|
18 |
from yolo.tools.loss_functions import create_loss_function
|
19 |
+
from yolo.utils.bounding_box_utils import Vec2Box
|
20 |
from yolo.utils.logging_utils import ProgressLogger, log_model_structure
|
21 |
from yolo.utils.model_utils import (
|
22 |
ExponentialMovingAverage,
|
23 |
+
PostProccess,
|
24 |
create_optimizer,
|
25 |
create_scheduler,
|
26 |
+
predicts_to_json,
|
27 |
)
|
28 |
|
29 |
|
|
|
179 |
validation_cfg: ValidationConfig,
|
180 |
model: YOLO,
|
181 |
vec2box: Vec2Box,
|
|
|
182 |
progress: ProgressLogger,
|
183 |
+
device,
|
184 |
):
|
185 |
self.model = model
|
|
|
186 |
self.device = device
|
187 |
self.progress = progress
|
188 |
|
189 |
+
self.post_proccess = PostProccess(vec2box, validation_cfg.nms)
|
190 |
+
self.json_path = os.path.join(self.progress.save_path, f"predict.json")
|
191 |
|
192 |
def solve(self, dataloader):
|
193 |
# logger.info("🧪 Start Validation!")
|
194 |
self.model.eval()
|
195 |
+
predict_json = []
|
|
|
|
|
196 |
self.progress.start_one_epoch(len(dataloader))
|
197 |
for images, targets, rev_tensor, img_paths in dataloader:
|
198 |
images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
|
199 |
with torch.no_grad():
|
200 |
predicts = self.model(images)
|
201 |
+
predicts = self.post_proccess(predicts, rev_tensor)
|
202 |
+
self.progress.one_batch()
|
|
|
|
|
|
|
|
|
203 |
|
204 |
+
predict_json.extend(predicts_to_json(img_paths, predicts))
|
205 |
self.progress.finish_one_epoch()
|
206 |
+
with open(self.json_path, "w") as f:
|
207 |
+
json.dump(predict_json, f)
|
yolo/utils/logging_utils.py
CHANGED
@@ -70,9 +70,9 @@ class ProgressLogger:
|
|
70 |
self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
|
71 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
72 |
|
73 |
-
def one_batch(self, loss_dict: Dict[str, Tensor] = None
|
74 |
if loss_dict is None:
|
75 |
-
self.progress.update(self.batch_task, advance=1, description=f"[green]
|
76 |
return
|
77 |
if self.use_wandb:
|
78 |
for loss_name, loss_value in loss_dict.items():
|
|
|
70 |
self.wandb.log({f"Learning Rate/{lr_name}": lr_value}, step=epoch_idx)
|
71 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
72 |
|
73 |
+
def one_batch(self, loss_dict: Dict[str, Tensor] = None):
|
74 |
if loss_dict is None:
|
75 |
+
self.progress.update(self.batch_task, advance=1, description=f"[green]Validating")
|
76 |
return
|
77 |
if self.use_wandb:
|
78 |
for loss_name, loss_value in loss_dict.items():
|