File size: 3,794 Bytes
6a39ae1
 
 
 
 
 
ac8e6e6
6a39ae1
 
 
 
4b46de4
6a39ae1
ac8e6e6
44abd6c
ac8e6e6
6a39ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b46de4
 
 
 
 
6a39ae1
 
 
 
 
 
 
 
 
 
 
4b46de4
 
 
 
 
 
ac8e6e6
 
3ebbbd9
 
ac8e6e6
bd0409b
ac8e6e6
 
 
 
 
 
 
3ebbbd9
ac8e6e6
 
 
 
6a39ae1
 
4b46de4
6a39ae1
 
 
4b46de4
 
 
 
 
 
 
 
6a39ae1
 
44abd6c
ac8e6e6
6a39ae1
 
 
 
44abd6c
ac8e6e6
6a39ae1
 
 
 
 
 
 
4b46de4
 
 
 
 
6a39ae1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import sys
from pathlib import Path

import pytest
import torch
from hydra import compose, initialize
from lightning import Trainer

# Expose the repository root (two levels above this file) on the import path so
# the local `yolo` package and its submodules can be imported by the lines below.
project_root = Path(__file__).resolve().parent.parent
sys.path.append(f"{project_root}")

from yolo import Anc2Box, Config, Vec2Box, create_converter, create_model
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
from yolo.tools.dataset_preparation import prepare_dataset
from yolo.utils.logging_utils import set_seed, setup


def pytest_configure(config):
    """Register the custom ``requires_cuda`` marker with pytest.

    Registering the marker in the ini "markers" list keeps pytest from warning
    about an unknown marker when tests are decorated with it.
    """
    marker_spec = "requires_cuda: mark test to run only if CUDA is available"
    config.addinivalue_line("markers", marker_spec)


def get_cfg(overrides=None) -> Config:
    """Compose the project's Hydra ``config`` and seed from its ``lucky_number``.

    Args:
        overrides: Hydra override strings such as ``"task=train"`` or
            ``"model=v7"``. ``None`` (the default) applies no overrides.

    Returns:
        The composed config, after ``set_seed(cfg.lucky_number)`` has been called.
    """
    # Build a fresh list per call rather than sharing one mutable default list
    # across every invocation.
    overrides = [] if overrides is None else list(overrides)
    # Hydra resolves a relative config_path against the calling module's directory,
    # so this points at <repo>/yolo/config from the tests folder.
    config_path = "../yolo/config"
    with initialize(config_path=config_path, version_base=None):
        cfg: Config = compose(config_name="config", overrides=overrides)
        set_seed(cfg.lucky_number)
        return cfg


@pytest.fixture(scope="session")
def train_cfg() -> Config:
    """Session-scoped config for the ``train`` task on the ``mock`` dataset."""
    train_overrides = ["task=train", "dataset=mock"]
    return get_cfg(overrides=train_overrides)


@pytest.fixture(scope="session")
def validation_cfg() -> Config:
    """Session-scoped config for the ``validation`` task on the ``mock`` dataset.

    Annotated with ``-> Config`` to match the sibling ``train_cfg`` fixture.
    """
    return get_cfg(overrides=["task=validation", "dataset=mock"])


@pytest.fixture(scope="session")
def inference_cfg() -> Config:
    """Session-scoped config for the ``inference`` task.

    No dataset or model override is applied, so whatever the base ``config``
    selects is used. Annotated with ``-> Config`` to match ``train_cfg``.
    """
    return get_cfg(overrides=["task=inference"])


@pytest.fixture(scope="session")
def inference_v7_cfg() -> Config:
    """Session-scoped ``inference`` config with the model switched to ``v7``.

    Annotated with ``-> Config`` to match ``train_cfg``.
    """
    return get_cfg(overrides=["task=inference", "model=v7"])


@pytest.fixture(scope="session")
def device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(scope="session")
def model(train_cfg: Config, device) -> YOLO:
    """Build the model described by ``train_cfg.model`` and move it to ``device``."""
    return create_model(train_cfg.model).to(device)


@pytest.fixture(scope="session")
def model_v7(inference_v7_cfg: Config, device) -> YOLO:
    """Build the model described by ``inference_v7_cfg.model`` on ``device``."""
    return create_model(inference_v7_cfg.model).to(device)


@pytest.fixture(scope="session")
def solver(train_cfg: Config) -> Trainer:
    """Create a Lightning ``Trainer`` configured from ``train_cfg``.

    Side effect: sets ``train_cfg.use_wandb = False`` on the shared,
    session-scoped config before calling ``setup``.
    """
    # Turned off before setup() runs; presumably setup() consults this flag when
    # building the loggers, so tests stay offline. Verify in yolo.utils.logging_utils.
    train_cfg.use_wandb = False
    callbacks, loggers, save_path = setup(train_cfg)
    trainer_kwargs = dict(
        accelerator="auto",
        # Tasks without an ``epoch`` field leave max_epochs to Lightning's default.
        max_epochs=getattr(train_cfg.task, "epoch", None),
        precision="16-mixed",
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=1,
        gradient_clip_val=10,
        deterministic=True,
        default_root_dir=save_path,
    )
    return Trainer(**trainer_kwargs)


@pytest.fixture(scope="session")
def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
    """Converter built from the training model config, the ``model`` fixture and ``device``."""
    model_cfg = train_cfg.model
    return create_converter(model_cfg.name, model, model_cfg.anchor, train_cfg.image_size, device)


@pytest.fixture(scope="session")
def anc2box(inference_v7_cfg: Config, model_v7: YOLO, device) -> Anc2Box:
    """Anchor-based converter for the v7 inference config.

    Uses the ``model_v7`` fixture, which is built from this same
    ``inference_v7_cfg``. The converter therefore receives the network that its
    anchor settings (``inference_v7_cfg.model.anchor``) describe, not the v9
    training model.
    """
    # NOTE(review): create_converter receives the model alongside the anchor config;
    # presumably it may read attributes from it or run it (e.g. to infer strides),
    # so model and anchors must come from the same config.
    model_cfg = inference_v7_cfg.model
    return create_converter(model_cfg.name, model_v7, model_cfg.anchor, inference_v7_cfg.image_size, device)


@pytest.fixture(scope="session")
def train_dataloader(train_cfg: Config):
    """Prepare the training split of the configured dataset, then build its loader."""
    dataset_cfg = train_cfg.dataset
    prepare_dataset(dataset_cfg, task="train")
    task_cfg = train_cfg.task
    return create_dataloader(task_cfg.data, dataset_cfg, task_cfg.task)


@pytest.fixture(scope="session")
def validation_dataloader(validation_cfg: Config):
    """Prepare the ``val`` split of the configured dataset, then build its loader."""
    dataset_cfg = validation_cfg.dataset
    prepare_dataset(dataset_cfg, task="val")
    task_cfg = validation_cfg.task
    return create_dataloader(task_cfg.data, dataset_cfg, task_cfg.task)


@pytest.fixture(scope="session")
def file_stream_data_loader(inference_cfg: Config):
    """Stream loader over the data section of the default inference config."""
    stream_cfg = inference_cfg.task.data
    return StreamDataLoader(stream_cfg)


@pytest.fixture(scope="session")
def file_stream_data_loader_v7(inference_v7_cfg: Config):
    """Stream loader over the data section of the v7 inference config."""
    stream_cfg = inference_v7_cfg.task.data
    return StreamDataLoader(stream_cfg)


@pytest.fixture(scope="session")
def directory_stream_data_loader(inference_cfg: Config):
    """Stream loader whose source is the ``tests/data/images/train`` directory.

    ``inference_cfg`` is session-scoped and shared with ``file_stream_data_loader``
    (which reads the same ``task.data`` section). The data section is therefore
    deep-copied before ``source`` is redirected. Assigning on the shared config
    would silently change the source seen by any fixture built afterwards.
    """
    import copy

    data_cfg = copy.deepcopy(inference_cfg.task.data)
    # Relative path; presumably resolved against the working directory by
    # StreamDataLoader (i.e. tests are expected to run from the repo root).
    data_cfg.source = "tests/data/images/train"
    return StreamDataLoader(data_cfg)