henry000 commited on
Commit
c6e3994
·
1 Parent(s): a757657

✅ [Pass] the test for lightning train and validate

Browse files
Files changed (1) hide show
  1. tests/test_tools/test_solver.py +20 -22
tests/test_tools/test_solver.py CHANGED
@@ -1,38 +1,39 @@
1
  import sys
 
2
  from pathlib import Path
3
 
4
  import pytest
5
- from torch import allclose, tensor
 
6
 
7
  project_root = Path(__file__).resolve().parent.parent.parent
8
  sys.path.append(str(project_root))
9
 
10
  from yolo.config.config import Config
11
  from yolo.model.yolo import YOLO
12
- from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
13
- from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
14
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
15
 
16
 
17
  @pytest.fixture
18
- def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
19
- validator = ModelValidator(
20
- validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
21
- )
22
  return validator
23
 
24
 
25
- def test_model_validator_initialization(model_validator: ModelValidator):
26
  assert isinstance(model_validator.model, YOLO)
27
- assert hasattr(model_validator, "solve")
28
 
29
 
30
- def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
31
- mAPs = model_validator.solve(validation_dataloader)
32
- except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
33
- assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=0.1)
34
- print(mAPs)
35
- assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=0.1)
 
36
 
37
 
38
  @pytest.fixture
@@ -63,17 +64,14 @@ def test_modelv7_tester_solve_single_image(modelv7_tester: ModelTester, file_str
63
  @pytest.fixture
64
  def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
65
  train_cfg.task.epoch = 2
66
- trainer = ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
67
  return trainer
68
 
69
 
70
- def test_model_trainer_initialization(model_trainer: ModelTrainer):
71
-
72
  assert isinstance(model_trainer.model, YOLO)
73
- assert hasattr(model_trainer, "solve")
74
- assert model_trainer.optimizer is not None
75
- assert model_trainer.scheduler is not None
76
- assert model_trainer.loss_fn is not None
77
 
78
 
79
  # def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
 
1
  import sys
2
+ from math import isclose
3
  from pathlib import Path
4
 
5
  import pytest
6
+ from lightning.pytorch import Trainer
7
+ from torch.utils.data import DataLoader
8
 
9
  project_root = Path(__file__).resolve().parent.parent.parent
10
  sys.path.append(str(project_root))
11
 
12
  from yolo.config.config import Config
13
  from yolo.model.yolo import YOLO
14
+ from yolo.tools.data_loader import StreamDataLoader
15
+ from yolo.tools.solver import TrainModel, ValidateModel
16
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
17
 
18
 
19
  @pytest.fixture
20
+ def model_validator(validation_cfg: Config):
21
+ validator = ValidateModel(validation_cfg)
 
 
22
  return validator
23
 
24
 
25
+ def test_model_validator_initialization(solver: Trainer, model_validator: ValidateModel):
26
  assert isinstance(model_validator.model, YOLO)
27
+ assert hasattr(solver, "validate")
28
 
29
 
30
+ def test_model_validator_solve_mock_dataset(
31
+ solver: Trainer, model_validator: ValidateModel, validation_dataloader: DataLoader
32
+ ):
33
+ mAPs = solver.validate(model_validator, dataloaders=validation_dataloader)[0]
34
+ except_mAPs = {"map_50": 0.7379, "map": 0.5617}
35
+ assert isclose(mAPs["map_50"], except_mAPs["map_50"], abs_tol=1e-4)
36
+ assert isclose(mAPs["map"], except_mAPs["map"], abs_tol=1e-4)
37
 
38
 
39
  @pytest.fixture
 
64
  @pytest.fixture
65
  def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
66
  train_cfg.task.epoch = 2
67
+ trainer = TrainModel(train_cfg)
68
  return trainer
69
 
70
 
71
+ def test_model_trainer_initialization(solver: Trainer, model_trainer: TrainModel):
 
72
  assert isinstance(model_trainer.model, YOLO)
73
+ assert hasattr(solver, "fit")
74
+ assert solver.optimizers is not None
 
 
75
 
76
 
77
  # def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):