YOLO / tests /test_tools /test_solver.py
henry000's picture
✅ [Pass] test for v7 structure!
4b46de4
raw
history blame
3.03 kB
import sys
from pathlib import Path
import pytest
from torch import allclose, tensor
# The `yolo` package is imported straight from the source tree rather than an
# installed distribution, so the directory three levels above this file is
# appended to sys.path before the `yolo.*` imports run.
_test_file = Path(__file__).resolve()
project_root = _test_file.parent.parent.parent
sys.path.append(str(project_root))
from yolo.config.config import Config
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
@pytest.fixture
def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
    """Provide a ModelValidator wired to the validation config's task and dataset sections.

    The model, vec2box converter, progress logger and device all come from the
    shared fixtures of the same name.
    """
    task_cfg, dataset_cfg = validation_cfg.task, validation_cfg.dataset
    return ModelValidator(task_cfg, dataset_cfg, model, vec2box, validation_progress_logger, device)
def test_model_validator_initialization(model_validator: ModelValidator):
    """The validator fixture holds a YOLO model and exposes a ``solve`` method."""
    wrapped_model = model_validator.model
    assert isinstance(wrapped_model, YOLO)
    assert hasattr(model_validator, "solve")
def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
    """Run validation on the mock dataset and compare both mAP metrics to reference values.

    Each metric returned by ``solve`` must be within a 10% relative tolerance of
    the hard-coded reference below. On failure, the message reports the metric
    name, the observed value and the reference value.
    """
    mAPs = model_validator.solve(validation_dataloader)
    # Hard-coded reference scores for the mock validation set; checked in this order.
    expected_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
    for metric, expected in expected_mAPs.items():
        observed = mAPs[metric]
        assert allclose(observed, expected, rtol=0.1), f"{metric}: got {observed}, expected ~{expected} (rtol=0.1)"
@pytest.fixture
def model_tester(inference_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
    """Provide a ModelTester for the default model, built with the vec2box converter fixture."""
    return ModelTester(inference_cfg, model, vec2box, validation_progress_logger, device)
@pytest.fixture
def modelv7_tester(inference_v7_cfg: Config, model_v7: YOLO, anc2box: Anc2Box, validation_progress_logger, device):
    """Provide a ModelTester for the v7 model, built with the anc2box converter fixture."""
    return ModelTester(inference_v7_cfg, model_v7, anc2box, validation_progress_logger, device)
def test_model_tester_initialization(model_tester: ModelTester):
    """The tester fixture holds a YOLO model and exposes a ``solve`` method."""
    wrapped_model = model_tester.model
    assert isinstance(wrapped_model, YOLO)
    assert hasattr(model_tester, "solve")
def test_model_tester_solve_single_image(model_tester: ModelTester, file_stream_data_loader: StreamDataLoader):
    """Smoke test: running the tester over the file stream loader must not raise."""
    run_inference = model_tester.solve
    run_inference(file_stream_data_loader)
def test_modelv7_tester_solve_single_image(modelv7_tester: ModelTester, file_stream_data_loader_v7: StreamDataLoader):
    """Smoke test: running the v7 tester over the v7 file stream loader must not raise."""
    run_inference = modelv7_tester.solve
    run_inference(file_stream_data_loader_v7)
@pytest.fixture
def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
    """Provide a ModelTrainer with DDP disabled, after forcing the configured epoch count to 2."""
    # NOTE(review): this writes into the injected train_cfg in place. If that fixture is
    # shared beyond function scope, the epoch override leaks into other tests -- confirm its scope.
    train_cfg.task.epoch = 2
    return ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
def test_model_trainer_initialization(model_trainer: ModelTrainer):
    """The trainer fixture holds a YOLO model, exposes ``solve``, and has its training components set."""
    assert isinstance(model_trainer.model, YOLO)
    assert hasattr(model_trainer, "solve")
    # Each of these attributes must have been populated by the constructor.
    for component in ("optimizer", "scheduler", "loss_fn"):
        assert getattr(model_trainer, component) is not None
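# NOTE(review): this test, which runs ModelTrainer.solve over the training dataloader, is
# currently disabled; the reason is not recorded here -- re-enable or remove deliberately.
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
#     model_trainer.solve(train_dataloader)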