YOLO / tests /test_tools /test_solver.py
henry000's picture
πŸ’š [Fix] CI build bugs, temporarily remove MPS
cdd9a11
raw
history blame
2.56 kB
import sys
from math import isclose
from pathlib import Path
import pytest
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))
from yolo.config.config import Config
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box
@pytest.fixture
def model_validator(validation_cfg: Config):
validator = ValidateModel(validation_cfg)
return validator
def test_model_validator_initialization(solver: Trainer, model_validator: ValidateModel):
assert isinstance(model_validator.model, YOLO)
assert hasattr(solver, "validate")
def test_model_validator_solve_mock_dataset(
solver: Trainer, model_validator: ValidateModel, validation_dataloader: DataLoader
):
mAPs = solver.validate(model_validator, dataloaders=validation_dataloader)[0]
except_mAPs = {"map_50": 0.7379, "map": 0.5617}
assert isclose(mAPs["map_50"], except_mAPs["map_50"], abs_tol=0.1)
assert isclose(mAPs["map"], except_mAPs["map"], abs_tol=0.1)
@pytest.fixture
def model_tester(inference_cfg: Config):
tester = InferenceModel(inference_cfg)
return tester
@pytest.fixture
def modelv7_tester(inference_v7_cfg: Config):
tester = InferenceModel(inference_v7_cfg)
return tester
def test_model_tester_initialization(solver: Trainer, model_tester: InferenceModel):
assert isinstance(model_tester.model, YOLO)
assert hasattr(solver, "predict")
def test_model_tester_solve_single_image(
solver: Trainer, model_tester: InferenceModel, file_stream_data_loader: StreamDataLoader
):
solver.predict(model_tester, file_stream_data_loader)
def test_modelv7_tester_solve_single_image(
solver: Trainer, modelv7_tester: InferenceModel, file_stream_data_loader_v7: StreamDataLoader
):
solver.predict(modelv7_tester, file_stream_data_loader_v7)
@pytest.fixture
def model_trainer(train_cfg: Config):
train_cfg.task.epoch = 2
trainer = TrainModel(train_cfg)
return trainer
def test_model_trainer_initialization(solver: Trainer, model_trainer: TrainModel):
assert isinstance(model_trainer.model, YOLO)
assert hasattr(solver, "fit")
assert solver.optimizers is not None
# def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
# model_trainer.solve(train_dataloader)