✅ [Pass] the test for lightning train and validate
Browse files- tests/test_tools/test_solver.py +20 -22
tests/test_tools/test_solver.py
CHANGED
@@ -1,38 +1,39 @@
|
|
1 |
import sys
|
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
import pytest
|
5 |
-
from
|
|
|
6 |
|
7 |
project_root = Path(__file__).resolve().parent.parent.parent
|
8 |
sys.path.append(str(project_root))
|
9 |
|
10 |
from yolo.config.config import Config
|
11 |
from yolo.model.yolo import YOLO
|
12 |
-
from yolo.tools.data_loader import StreamDataLoader
|
13 |
-
from yolo.tools.solver import
|
14 |
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
|
15 |
|
16 |
|
17 |
@pytest.fixture
|
18 |
-
def model_validator(validation_cfg: Config
|
19 |
-
validator =
|
20 |
-
validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
|
21 |
-
)
|
22 |
return validator
|
23 |
|
24 |
|
25 |
-
def test_model_validator_initialization(model_validator:
|
26 |
assert isinstance(model_validator.model, YOLO)
|
27 |
-
assert hasattr(
|
28 |
|
29 |
|
30 |
-
def test_model_validator_solve_mock_dataset(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
assert
|
|
|
36 |
|
37 |
|
38 |
@pytest.fixture
|
@@ -63,17 +64,14 @@ def test_modelv7_tester_solve_single_image(modelv7_tester: ModelTester, file_str
|
|
63 |
@pytest.fixture
|
64 |
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
|
65 |
train_cfg.task.epoch = 2
|
66 |
-
trainer =
|
67 |
return trainer
|
68 |
|
69 |
|
70 |
-
def test_model_trainer_initialization(model_trainer:
|
71 |
-
|
72 |
assert isinstance(model_trainer.model, YOLO)
|
73 |
-
assert hasattr(
|
74 |
-
assert
|
75 |
-
assert model_trainer.scheduler is not None
|
76 |
-
assert model_trainer.loss_fn is not None
|
77 |
|
78 |
|
79 |
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
|
|
|
1 |
import sys
|
2 |
+
from math import isclose
|
3 |
from pathlib import Path
|
4 |
|
5 |
import pytest
|
6 |
+
from lightning.pytorch import Trainer
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
|
9 |
project_root = Path(__file__).resolve().parent.parent.parent
|
10 |
sys.path.append(str(project_root))
|
11 |
|
12 |
from yolo.config.config import Config
|
13 |
from yolo.model.yolo import YOLO
|
14 |
+
from yolo.tools.data_loader import StreamDataLoader
|
15 |
+
from yolo.tools.solver import TrainModel, ValidateModel
|
16 |
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
|
17 |
|
18 |
|
19 |
@pytest.fixture
|
20 |
+
def model_validator(validation_cfg: Config):
|
21 |
+
validator = ValidateModel(validation_cfg)
|
|
|
|
|
22 |
return validator
|
23 |
|
24 |
|
25 |
+
def test_model_validator_initialization(solver: Trainer, model_validator: ValidateModel):
|
26 |
assert isinstance(model_validator.model, YOLO)
|
27 |
+
assert hasattr(solver, "validate")
|
28 |
|
29 |
|
30 |
+
def test_model_validator_solve_mock_dataset(
|
31 |
+
solver: Trainer, model_validator: ValidateModel, validation_dataloader: DataLoader
|
32 |
+
):
|
33 |
+
mAPs = solver.validate(model_validator, dataloaders=validation_dataloader)[0]
|
34 |
+
except_mAPs = {"map_50": 0.7379, "map": 0.5617}
|
35 |
+
assert isclose(mAPs["map_50"], except_mAPs["map_50"], abs_tol=1e-4)
|
36 |
+
assert isclose(mAPs["map"], except_mAPs["map"], abs_tol=1e-4)
|
37 |
|
38 |
|
39 |
@pytest.fixture
|
|
|
64 |
@pytest.fixture
|
65 |
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
|
66 |
train_cfg.task.epoch = 2
|
67 |
+
trainer = TrainModel(train_cfg)
|
68 |
return trainer
|
69 |
|
70 |
|
71 |
+
def test_model_trainer_initialization(solver: Trainer, model_trainer: TrainModel):
|
|
|
72 |
assert isinstance(model_trainer.model, YOLO)
|
73 |
+
assert hasattr(solver, "fit")
|
74 |
+
assert solver.optimizers is not None
|
|
|
|
|
75 |
|
76 |
|
77 |
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
|