File size: 3,676 Bytes
88e45b9 b038f54 88e45b9 b038f54 88e45b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from hydra import compose, initialize
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))
from yolo.config.config import (
Config,
DataConfig,
LossConfig,
TrainConfig,
ValidationConfig,
)
from yolo.model.yolo import YOLO, create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.loss_functions import create_loss_function
from yolo.tools.solver import ( # Adjust the import to your module
ModelTester,
ModelTrainer,
ModelValidator,
)
from yolo.utils.bounding_box_utils import Vec2Box
from yolo.utils.logging_utils import ProgressLogger
from yolo.utils.model_utils import (
ExponentialMovingAverage,
create_optimizer,
create_scheduler,
)
@pytest.fixture
def cfg() -> Config:
    """Compose the default Hydra ``config`` and clear its ``weight`` entry.

    The config is composed from ``yolo/config`` (relative to this test file).
    ``weight`` is set to ``None`` so tests do not rely on a pretrained weight
    setting (presumably a checkpoint path — confirm against the config schema).
    """
    with initialize(config_path="../../yolo/config", version_base=None):
        composed: Config = compose(config_name="config")
        composed.weight = None
        return composed
@pytest.fixture
def cfg_validaion() -> Config:
    """Compose the Hydra ``config`` with ``task=validation`` and clear ``weight``.

    NOTE: the fixture name is misspelled ("validaion"), but it is kept as-is
    because pytest injects fixtures by parameter name and tests request it so.
    """
    validation_overrides = ["task=validation"]
    with initialize(config_path="../../yolo/config", version_base=None):
        validation_cfg: Config = compose(config_name="config", overrides=validation_overrides)
        validation_cfg.weight = None
        return validation_cfg
@pytest.fixture
def cfg_inference() -> Config:
    """Compose the Hydra ``config`` with ``task=inference`` and clear ``weight``."""
    inference_overrides = ["task=inference"]
    with initialize(config_path="../../yolo/config", version_base=None):
        inference_cfg: Config = compose(config_name="config", overrides=inference_overrides)
        inference_cfg.weight = None
        return inference_cfg
@pytest.fixture
def device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
@pytest.fixture
def model(cfg: Config, device) -> YOLO:
    """Build a YOLO model from ``cfg.model`` without loading weights, moved to ``device``."""
    return create_model(cfg.model, weight_path=None).to(device)
@pytest.fixture
def vec2box(cfg: Config, model: YOLO, device) -> Vec2Box:
    """Build a ``Vec2Box`` for the injected ``model`` at ``cfg.image_size`` on ``device``.

    The ``model`` fixture is used directly rather than constructing a second model.
    Rebuilding it (as this fixture previously did) wasted a full model build. It
    also meant the converter was derived from a different instance than the
    ``model`` handed to the tests. The ``model`` fixture has already been moved
    to ``device``.
    """
    return Vec2Box(model, cfg.image_size, device)
@pytest.fixture
def progress_logger(cfg: Config):
    """Create a ``ProgressLogger`` for ``cfg``, using ``cfg.name`` as the experiment name."""
    return ProgressLogger(cfg, exp_name=cfg.name)
# def test_model_trainer_initialization(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
# trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
# assert trainer.model == model
# assert trainer.device == device
# assert trainer.optimizer is not None
# assert trainer.scheduler is not None
# assert trainer.loss_fn is not None
# assert trainer.progress == progress_logger
# def test_model_trainer_train_one_batch(config, model, vec2box, progress_logger, device):
# trainer = ModelTrainer(config, model, vec2box, progress_logger, device, use_ddp=False)
# images = torch.rand(1, 3, 224, 224)
# targets = torch.rand(1, 5)
# loss_item = trainer.train_one_batch(images, targets)
# assert isinstance(loss_item, dict)
def test_model_validator_initialization(cfg_validaion: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """ModelValidator must expose the model, device and progress logger it was built with."""
    task_cfg = cfg_validaion.task
    dataset_cfg = cfg_validaion.dataset
    validator = ModelValidator(task_cfg, dataset_cfg, model, vec2box, progress_logger, device)

    # Checked one at a time, in order; getattr keeps each attribute access lazy.
    for attr, expected in (("model", model), ("device", device), ("progress", progress_logger)):
        assert getattr(validator, attr) == expected
def test_model_tester_initialization(cfg_inference: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """ModelTester must expose the model, device and progress logger it was built with."""
    tester = ModelTester(cfg_inference, model, vec2box, progress_logger, device)

    # Checked one at a time, in order; getattr keeps each attribute access lazy.
    for attr, expected in (("model", model), ("device", device), ("progress", progress_logger)):
        assert getattr(tester, attr) == expected
|