File size: 3,643 Bytes
6a39ae1 4b46de4 6a39ae1 44abd6c 6a39ae1 4b46de4 6a39ae1 4b46de4 6a39ae1 4b46de4 6a39ae1 4b46de4 6a39ae1 44abd6c 6a39ae1 44abd6c 6a39ae1 4b46de4 6a39ae1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import sys
from pathlib import Path
import pytest
import torch
from hydra import compose, initialize
# Put the directory one level above this file's directory on sys.path so the
# `yolo` imports that follow can be resolved.
_this_dir = Path(__file__).resolve().parent
project_root = _this_dir.parent
sys.path.append(str(project_root))
from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
from yolo.tools.dataset_preparation import prepare_dataset
from yolo.utils.logging_utils import ProgressLogger, set_seed
def pytest_configure(config):
    """Register the custom ``requires_cuda`` marker on the given pytest config.

    Args:
        config: Object exposing ``addinivalue_line``; the marker description
            is appended to its ``markers`` ini setting.
    """
    marker_line = "requires_cuda: mark test to run only if CUDA is available"
    config.addinivalue_line("markers", marker_line)
def get_cfg(overrides=None) -> Config:
    """Compose the Hydra config named ``config`` with optional overrides.

    After composing, ``set_seed`` is called with the config's ``lucky_number``.

    Args:
        overrides: Optional list of Hydra override strings such as
            ``"task=train"``. ``None`` (the default) means no overrides.

    Returns:
        The composed configuration object.
    """
    # Use a None sentinel instead of a mutable `[]` default so no list object
    # is shared between calls.
    overrides = [] if overrides is None else overrides
    config_path = "../yolo/config"
    with initialize(config_path=config_path, version_base=None):
        cfg: Config = compose(config_name="config", overrides=overrides)
        set_seed(cfg.lucky_number)
        return cfg
@pytest.fixture(scope="session")
def train_cfg() -> Config:
    """Session-scoped config composed with the train task and mock dataset overrides."""
    train_overrides = ["task=train", "dataset=mock"]
    return get_cfg(overrides=train_overrides)
@pytest.fixture(scope="session")
def validation_cfg():
    """Session-scoped config composed with the validation task and mock dataset overrides."""
    validation_overrides = ["task=validation", "dataset=mock"]
    return get_cfg(overrides=validation_overrides)
@pytest.fixture(scope="session")
def inference_cfg():
    """Session-scoped config composed with the inference task override."""
    inference_overrides = ["task=inference"]
    return get_cfg(overrides=inference_overrides)
@pytest.fixture(scope="session")
def inference_v7_cfg():
    """Session-scoped config composed with the inference task and the ``v7`` model overrides."""
    v7_overrides = ["task=inference", "model=v7"]
    return get_cfg(overrides=v7_overrides)
@pytest.fixture(scope="session")
def device():
    """Return the ``cuda`` torch device when CUDA is available, otherwise ``cpu``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
@pytest.fixture(scope="session")
def train_progress_logger(train_cfg: Config):
    """Build a ``ProgressLogger`` from the train config, using the config's ``name`` as ``exp_name``."""
    return ProgressLogger(train_cfg, exp_name=train_cfg.name)
@pytest.fixture(scope="session")
def validation_progress_logger(validation_cfg: Config):
    """Build a ``ProgressLogger`` from the validation config, using the config's ``name`` as ``exp_name``."""
    return ProgressLogger(validation_cfg, exp_name=validation_cfg.name)
@pytest.fixture(scope="session")
def model(train_cfg: Config, device) -> YOLO:
    """Create the model described by ``train_cfg.model`` and move it to ``device``."""
    return create_model(train_cfg.model).to(device)
@pytest.fixture(scope="session")
def model_v7(inference_v7_cfg: Config, device) -> YOLO:
    """Create the model described by ``inference_v7_cfg.model`` and move it to ``device``."""
    return create_model(inference_v7_cfg.model).to(device)
@pytest.fixture(scope="session")
def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
    """Create the converter for the train config's model name, anchor settings and image size."""
    model_cfg = train_cfg.model
    return create_converter(
        model_cfg.name,
        model,
        model_cfg.anchor,
        train_cfg.image_size,
        device,
    )
@pytest.fixture(scope="session")
def anc2box(inference_v7_cfg: Config, model: YOLO, device) -> Anc2Box:
    """Create the converter for the v7 inference config's model name, anchor settings and image size."""
    # NOTE(review): this receives the `model` fixture (built from the train
    # config), not `model_v7` — confirm that pairing is intended.
    v7_model_cfg = inference_v7_cfg.model
    return create_converter(
        v7_model_cfg.name,
        model,
        v7_model_cfg.anchor,
        inference_v7_cfg.image_size,
        device,
    )
@pytest.fixture(scope="session")
def train_dataloader(train_cfg: Config):
    """Prepare the train split of the configured dataset, then build its ``YoloDataLoader``."""
    dataset_cfg = train_cfg.dataset
    prepare_dataset(dataset_cfg, task="train")
    task_cfg = train_cfg.task
    return YoloDataLoader(task_cfg.data, dataset_cfg, task_cfg.task)
@pytest.fixture(scope="session")
def validation_dataloader(validation_cfg: Config):
    """Prepare the val split of the configured dataset, then build its ``YoloDataLoader``."""
    dataset_cfg = validation_cfg.dataset
    prepare_dataset(dataset_cfg, task="val")
    task_cfg = validation_cfg.task
    return YoloDataLoader(task_cfg.data, dataset_cfg, task_cfg.task)
@pytest.fixture(scope="session")
def file_stream_data_loader(inference_cfg: Config):
    """Build a ``StreamDataLoader`` from the inference config's ``task.data`` section."""
    data_cfg = inference_cfg.task.data
    return StreamDataLoader(data_cfg)
@pytest.fixture(scope="session")
def file_stream_data_loader_v7(inference_v7_cfg: Config):
    """Build a ``StreamDataLoader`` from the v7 inference config's ``task.data`` section."""
    data_cfg = inference_v7_cfg.task.data
    return StreamDataLoader(data_cfg)
@pytest.fixture(scope="session")
def directory_stream_data_loader(inference_cfg: Config):
    """Build a ``StreamDataLoader`` whose source is the ``tests/data/images/train`` directory.

    The loader is built from a deep copy of the inference config's
    ``task.data`` section. ``inference_cfg`` is a session-scoped fixture shared
    with other fixtures, so setting ``source`` on it directly would leak the
    directory path into every later consumer of that config.
    """
    import copy  # function-scope import keeps this fixture self-contained

    data_cfg = copy.deepcopy(inference_cfg.task.data)
    data_cfg.source = "tests/data/images/train"
    return StreamDataLoader(data_cfg)
|