|
import sys |
|
from pathlib import Path |
|
from unittest.mock import MagicMock, patch |
|
|
|
import pytest |
|
import torch |
|
from hydra import compose, initialize |
|
|
|
project_root = Path(__file__).resolve().parent.parent.parent |
|
sys.path.append(str(project_root)) |
|
|
|
from yolo.config.config import ( |
|
Config, |
|
DataConfig, |
|
LossConfig, |
|
TrainConfig, |
|
ValidationConfig, |
|
) |
|
from yolo.model.yolo import YOLO, create_model |
|
from yolo.tools.data_loader import create_dataloader |
|
from yolo.tools.loss_functions import create_loss_function |
|
from yolo.tools.solver import ( |
|
ModelTester, |
|
ModelTrainer, |
|
ModelValidator, |
|
) |
|
from yolo.utils.bounding_box_utils import Vec2Box |
|
from yolo.utils.logging_utils import ProgressLogger |
|
from yolo.utils.model_utils import ( |
|
ExponentialMovingAverage, |
|
create_optimizer, |
|
create_scheduler, |
|
) |
|
|
|
|
|
@pytest.fixture
def cfg() -> Config:
    """Compose the default ``config`` Hydra config for solver tests.

    ``weight`` is forced to ``None`` so no checkpoint path is taken from the
    config.
    """
    with initialize(config_path="../../yolo/config", version_base=None):
        composed: Config = compose(config_name="config")
        composed.weight = None
    return composed
|
|
|
|
|
@pytest.fixture
def cfg_validaion() -> Config:
    """Compose the Hydra config with the ``task=validation`` override applied.

    ``weight`` is forced to ``None`` so no checkpoint path is taken from the
    config. (The fixture name's spelling is kept because tests request it by
    this exact name.)
    """
    with initialize(config_path="../../yolo/config", version_base=None):
        composed: Config = compose(config_name="config", overrides=["task=validation"])
        composed.weight = None
    return composed
|
|
|
|
|
@pytest.fixture
def cfg_inference() -> Config:
    """Compose the Hydra config with the ``task=inference`` override applied.

    ``weight`` is forced to ``None`` so no checkpoint path is taken from the
    config.
    """
    with initialize(config_path="../../yolo/config", version_base=None):
        composed: Config = compose(config_name="config", overrides=["task=inference"])
        composed.weight = None
    return composed
|
|
|
|
|
@pytest.fixture
def device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
|
|
|
|
@pytest.fixture
def model(cfg: Config, device) -> YOLO:
    """Build a YOLO model from ``cfg.model`` with no weights and move it to ``device``."""
    return create_model(cfg.model, weight_path=None).to(device)
|
|
|
|
|
@pytest.fixture
def vec2box(cfg: Config, model: YOLO, device) -> Vec2Box:
    """Build a ``Vec2Box`` for the shared test model at ``cfg.image_size``.

    Uses the injected ``model`` fixture directly. The previous version shadowed
    that argument with a freshly created, identical network, so every test
    built two models and the injected one was never used.

    NOTE(review): if ``Vec2Box`` runs the model during construction, it now does
    so on the same instance the tests receive — confirm that is acceptable.
    """
    return Vec2Box(model, cfg.image_size, device)
|
|
|
|
|
@pytest.fixture |
|
def progress_logger(cfg: Config): |
|
progress_logger = ProgressLogger(cfg, exp_name=cfg.name) |
|
return progress_logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_validator_initialization(cfg_validaion: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """ModelValidator keeps the collaborators it is constructed with.

    The validator is built from the validation task/dataset config and the
    shared model, vec2box, progress logger and device fixtures.
    """
    validator = ModelValidator(cfg_validaion.task, cfg_validaion.dataset, model, vec2box, progress_logger, device)

    # Identity, not equality: the validator must hold the exact injected
    # objects rather than copies.
    assert validator.model is model
    assert validator.progress is progress_logger
    # torch.device is a value type, so equality is the right comparison.
    assert validator.device == device
|
|
|
|
|
def test_model_tester_initialization(cfg_inference: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """ModelTester keeps the collaborators it is constructed with.

    The tester is built from the full inference config and the shared model,
    vec2box, progress logger and device fixtures.
    """
    tester = ModelTester(cfg_inference, model, vec2box, progress_logger, device)

    # Identity, not equality: the tester must hold the exact injected objects
    # rather than copies.
    assert tester.model is model
    assert tester.progress is progress_logger
    # torch.device is a value type, so equality is the right comparison.
    assert tester.device == device
|
|