File size: 3,643 Bytes
6a39ae1
 
 
 
 
 
 
 
 
 
4b46de4
6a39ae1
 
44abd6c
6a39ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b46de4
 
 
 
 
6a39ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b46de4
 
 
 
 
 
6a39ae1
 
4b46de4
6a39ae1
 
 
4b46de4
 
 
 
 
 
 
 
6a39ae1
 
44abd6c
6a39ae1
 
 
 
 
44abd6c
6a39ae1
 
 
 
 
 
 
 
4b46de4
 
 
 
 
6a39ae1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
from pathlib import Path

import pytest
import torch
from hydra import compose, initialize

project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))

from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
from yolo.tools.dataset_preparation import prepare_dataset
from yolo.utils.logging_utils import ProgressLogger, set_seed


def pytest_configure(config):
    """Register the custom ``requires_cuda`` marker on the pytest config."""
    marker_line = "requires_cuda: mark test to run only if CUDA is available"
    config.addinivalue_line("markers", marker_line)


def get_cfg(overrides=None) -> Config:
    """Compose the project's Hydra config and seed the RNGs from it.

    Args:
        overrides: Optional list of Hydra override strings (for example
            ``["task=train", "dataset=mock"]``). ``None`` (the default) means
            "no overrides".

    Returns:
        The composed ``Config`` produced from the ``config`` entry point.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if overrides is None:
        overrides = []

    # Hydra resolves ``config_path`` relative to the file calling ``initialize``.
    config_path = "../yolo/config"
    with initialize(config_path=config_path, version_base=None):
        cfg: Config = compose(config_name="config", overrides=overrides)
        # Seed right after composing, using the value stored in the config.
        set_seed(cfg.lucky_number)
        return cfg


@pytest.fixture(scope="session")
def train_cfg() -> Config:
    """Session-wide config composed with the train task and the mock dataset."""
    train_overrides = ["task=train", "dataset=mock"]
    return get_cfg(overrides=train_overrides)


@pytest.fixture(scope="session")
def validation_cfg():
    """Session-wide config composed with the validation task and the mock dataset."""
    validation_overrides = ["task=validation", "dataset=mock"]
    return get_cfg(overrides=validation_overrides)


@pytest.fixture(scope="session")
def inference_cfg():
    """Session-wide config composed with the inference task override."""
    inference_overrides = ["task=inference"]
    return get_cfg(overrides=inference_overrides)


@pytest.fixture(scope="session")
def inference_v7_cfg():
    """Session-wide config composed with the inference task and the ``v7`` model."""
    v7_overrides = ["task=inference", "model=v7"]
    return get_cfg(overrides=v7_overrides)


@pytest.fixture(scope="session")
def device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(scope="session")
def train_progress_logger(train_cfg: Config):
    """Build a ``ProgressLogger`` from the training config, using its ``name`` as ``exp_name``."""
    return ProgressLogger(train_cfg, exp_name=train_cfg.name)


@pytest.fixture(scope="session")
def validation_progress_logger(validation_cfg: Config):
    """Build a ``ProgressLogger`` from the validation config, using its ``name`` as ``exp_name``."""
    return ProgressLogger(validation_cfg, exp_name=validation_cfg.name)


@pytest.fixture(scope="session")
def model(train_cfg: Config, device) -> YOLO:
    """Create the model described by ``train_cfg.model`` and move it to ``device``."""
    return create_model(train_cfg.model).to(device)


@pytest.fixture(scope="session")
def model_v7(inference_v7_cfg: Config, device) -> YOLO:
    """Create the model described by ``inference_v7_cfg.model`` and move it to ``device``."""
    return create_model(inference_v7_cfg.model).to(device)


@pytest.fixture(scope="session")
def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
    """Build the converter for the training model via ``create_converter``."""
    model_cfg = train_cfg.model
    return create_converter(
        model_cfg.name,
        model,
        model_cfg.anchor,
        train_cfg.image_size,
        device,
    )


@pytest.fixture(scope="session")
def anc2box(inference_v7_cfg: Config, model: YOLO, device) -> Anc2Box:
    """Build the converter from the v7 inference config via ``create_converter``."""
    # NOTE(review): the ``model`` fixture is passed here while the name, anchor
    # and image size come from ``inference_v7_cfg`` — confirm this pairing is
    # intended.
    v7_model_cfg = inference_v7_cfg.model
    return create_converter(
        v7_model_cfg.name,
        model,
        v7_model_cfg.anchor,
        inference_v7_cfg.image_size,
        device,
    )


@pytest.fixture(scope="session")
def train_dataloader(train_cfg: Config):
    """Prepare the dataset for the ``train`` task, then build its ``YoloDataLoader``."""
    dataset_cfg = train_cfg.dataset
    prepare_dataset(dataset_cfg, task="train")

    task_cfg = train_cfg.task
    return YoloDataLoader(task_cfg.data, dataset_cfg, task_cfg.task)


@pytest.fixture(scope="session")
def validation_dataloader(validation_cfg: Config):
    """Prepare the dataset for the ``val`` task, then build its ``YoloDataLoader``."""
    dataset_cfg = validation_cfg.dataset
    prepare_dataset(dataset_cfg, task="val")

    task_cfg = validation_cfg.task
    return YoloDataLoader(task_cfg.data, dataset_cfg, task_cfg.task)


@pytest.fixture(scope="session")
def file_stream_data_loader(inference_cfg: Config):
    """Build a ``StreamDataLoader`` from the inference config's ``task.data`` section."""
    data_cfg = inference_cfg.task.data
    return StreamDataLoader(data_cfg)


@pytest.fixture(scope="session")
def file_stream_data_loader_v7(inference_v7_cfg: Config):
    """Build a ``StreamDataLoader`` from the v7 inference config's ``task.data`` section."""
    data_cfg = inference_v7_cfg.task.data
    return StreamDataLoader(data_cfg)


@pytest.fixture(scope="session")
def directory_stream_data_loader(inference_cfg: Config):
    """Build a ``StreamDataLoader`` whose source is the ``tests/data/images/train`` path."""
    data_cfg = inference_cfg.task.data
    # NOTE(review): this assigns onto the session-scoped ``inference_cfg``
    # object itself, so the new source is visible to anything else holding that
    # config afterwards — confirm that is intended.
    data_cfg.source = "tests/data/images/train"
    return StreamDataLoader(data_cfg)