File size: 3,794 Bytes
6a39ae1 ac8e6e6 6a39ae1 4b46de4 6a39ae1 ac8e6e6 44abd6c ac8e6e6 6a39ae1 4b46de4 6a39ae1 4b46de4 ac8e6e6 3ebbbd9 ac8e6e6 bd0409b ac8e6e6 3ebbbd9 ac8e6e6 6a39ae1 4b46de4 6a39ae1 4b46de4 6a39ae1 44abd6c ac8e6e6 6a39ae1 44abd6c ac8e6e6 6a39ae1 4b46de4 6a39ae1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import sys
from pathlib import Path
import pytest
import torch
from hydra import compose, initialize
from lightning import Trainer
# Put the repository root (the grandparent directory of this file) on
# ``sys.path`` so the in-repo ``yolo`` package imported below is resolvable.
project_root = Path(__file__).resolve().parents[1]
sys.path.append(str(project_root))
from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
from yolo.tools.dataset_preparation import prepare_dataset
from yolo.utils.logging_utils import set_seed, setup
def pytest_configure(config):
    """Register the custom ``requires_cuda`` marker with pytest.

    Declaring it under the ``markers`` ini key lets pytest recognize the
    marker (for example when running with ``--strict-markers``).
    """
    marker_spec = "requires_cuda: mark test to run only if CUDA is available"
    config.addinivalue_line("markers", marker_spec)
def get_cfg(overrides=None) -> Config:
    """Compose the project's Hydra config for tests and seed the RNGs.

    Args:
        overrides: Hydra override strings such as ``"task=train"``. ``None``
            (the default) applies no overrides.

    Returns:
        The composed config. ``set_seed`` has already been called with its
        ``lucky_number``.
    """
    # A ``None`` default instead of a mutable ``[]``, so no list object is
    # shared between calls. Copy so the caller's sequence is never handed
    # to Hydra directly.
    overrides = [] if overrides is None else list(overrides)
    # Hydra's ``initialize`` resolves ``config_path`` relative to the calling
    # file, not the current working directory.
    config_path = "../yolo/config"
    with initialize(config_path=config_path, version_base=None):
        cfg: Config = compose(config_name="config", overrides=overrides)
        set_seed(cfg.lucky_number)
        return cfg
@pytest.fixture(scope="session")
def train_cfg() -> Config:
    """Session-wide config for the ``train`` task on the ``mock`` dataset."""
    train_overrides = ["task=train", "dataset=mock"]
    return get_cfg(train_overrides)
@pytest.fixture(scope="session")
def validation_cfg():
    """Session-wide config for the ``validation`` task on the ``mock`` dataset."""
    validation_overrides = ["task=validation", "dataset=mock"]
    return get_cfg(validation_overrides)
@pytest.fixture(scope="session")
def inference_cfg():
    """Session-wide config for the ``inference`` task with the default model."""
    inference_overrides = ["task=inference"]
    return get_cfg(inference_overrides)
@pytest.fixture(scope="session")
def inference_v7_cfg():
    """Session-wide config for the ``inference`` task using ``model=v7``."""
    v7_overrides = ["task=inference", "model=v7"]
    return get_cfg(v7_overrides)
@pytest.fixture(scope="session")
def device():
    """Torch device shared by the session: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
@pytest.fixture(scope="session")
def model(train_cfg: Config, device) -> YOLO:
    """YOLO network built from ``train_cfg.model`` and moved onto ``device``."""
    return create_model(train_cfg.model).to(device)
@pytest.fixture(scope="session")
def model_v7(inference_v7_cfg: Config, device) -> YOLO:
    """YOLO network built from ``inference_v7_cfg.model`` and moved onto ``device``."""
    return create_model(inference_v7_cfg.model).to(device)
@pytest.fixture(scope="session")
def solver(train_cfg: Config) -> Trainer:
    """Lightning ``Trainer`` configured from the session training config.

    Side effect: sets ``use_wandb = False`` on the shared, session-scoped
    ``train_cfg`` before calling ``setup``. Every later consumer of
    ``train_cfg`` sees that change.
    """
    # Flip the flag before ``setup`` runs. Presumably this keeps ``setup``
    # from creating a Weights & Biases logger during tests (verify in setup).
    train_cfg.use_wandb = False
    callbacks, loggers, save_path = setup(train_cfg)
    # Tasks without an ``epoch`` field fall back to Lightning's default.
    epoch_limit = getattr(train_cfg.task, "epoch", None)
    return Trainer(
        accelerator="auto",
        max_epochs=epoch_limit,
        precision="16-mixed",
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=1,
        gradient_clip_val=10,
        deterministic=True,
        default_root_dir=save_path,
    )
@pytest.fixture(scope="session")
def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
    """Box converter for the training model, built via ``create_converter``."""
    model_cfg = train_cfg.model
    return create_converter(model_cfg.name, model, model_cfg.anchor, train_cfg.image_size, device)
@pytest.fixture(scope="session")
def anc2box(inference_v7_cfg: Config, model_v7: YOLO, device) -> Anc2Box:
    """Anchor-based box converter built from the v7 inference config.

    The converter is paired with the ``model_v7`` fixture, which is built from
    the same ``inference_v7_cfg``. That keeps the model, the model name and
    the anchor settings passed to ``create_converter`` from one consistent
    config, instead of mixing the v7 anchors with the ``model`` fixture built
    from ``train_cfg``.
    """
    v7_model_cfg = inference_v7_cfg.model
    return create_converter(v7_model_cfg.name, model_v7, v7_model_cfg.anchor, inference_v7_cfg.image_size, device)
@pytest.fixture(scope="session")
def train_dataloader(train_cfg: Config):
    """Dataloader for the training task.

    Calls ``prepare_dataset`` with ``task="train"`` first, then builds the
    loader with ``create_dataloader``.
    """
    dataset_cfg = train_cfg.dataset
    prepare_dataset(dataset_cfg, task="train")
    task_cfg = train_cfg.task
    return create_dataloader(task_cfg.data, dataset_cfg, task_cfg.task)
@pytest.fixture(scope="session")
def validation_dataloader(validation_cfg: Config):
    """Dataloader for the validation task.

    Calls ``prepare_dataset`` with ``task="val"`` first, then builds the
    loader with ``create_dataloader``.
    """
    dataset_cfg = validation_cfg.dataset
    prepare_dataset(dataset_cfg, task="val")
    task_cfg = validation_cfg.task
    return create_dataloader(task_cfg.data, dataset_cfg, task_cfg.task)
@pytest.fixture(scope="session")
def file_stream_data_loader(inference_cfg: Config):
    """``StreamDataLoader`` over the data source configured in ``inference_cfg``."""
    stream_data_cfg = inference_cfg.task.data
    return StreamDataLoader(stream_data_cfg)
@pytest.fixture(scope="session")
def file_stream_data_loader_v7(inference_v7_cfg: Config):
    """``StreamDataLoader`` over the data source configured in ``inference_v7_cfg``."""
    stream_data_cfg = inference_v7_cfg.task.data
    return StreamDataLoader(stream_data_cfg)
@pytest.fixture(scope="session")
def directory_stream_data_loader(inference_cfg: Config):
    """``StreamDataLoader`` whose source is the ``tests/data/images/train`` directory.

    Works on a deep copy of ``inference_cfg.task.data``. Overriding ``source``
    therefore no longer leaks into the session-scoped ``inference_cfg``, which
    other fixtures (e.g. ``file_stream_data_loader``) and tests share.
    Previously their behavior depended on fixture execution order.
    """
    from copy import deepcopy

    data_cfg = deepcopy(inference_cfg.task.data)
    # Relative path; presumably resolved against the directory pytest is
    # launched from (the project root). Verify in StreamDataLoader.
    data_cfg.source = "tests/data/images/train"
    return StreamDataLoader(data_cfg)
|