File size: 3,676 Bytes
88e45b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b038f54
 
 
 
 
 
 
 
88e45b9
 
 
 
 
 
 
 
 
 
 
b038f54
88e45b9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import torch
from hydra import compose, initialize

project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))

from yolo.config.config import (
    Config,
    DataConfig,
    LossConfig,
    TrainConfig,
    ValidationConfig,
)
from yolo.model.yolo import YOLO, create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.loss_functions import create_loss_function
from yolo.tools.solver import (  # Adjust the import to your module
    ModelTester,
    ModelTrainer,
    ModelValidator,
)
from yolo.utils.bounding_box_utils import Vec2Box
from yolo.utils.logging_utils import ProgressLogger
from yolo.utils.model_utils import (
    ExponentialMovingAverage,
    create_optimizer,
    create_scheduler,
)


@pytest.fixture
def cfg() -> Config:
    """Compose the default Hydra ``config`` with its ``weight`` entry cleared.

    ``weight`` is set to ``None`` before the config is handed out
    (presumably so no checkpoint is loaded by downstream fixtures).
    """
    with initialize(config_path="../../yolo/config", version_base=None):
        composed: Config = compose(config_name="config")
        composed.weight = None
    return composed


@pytest.fixture
def cfg_validaion() -> Config:
    """Compose the Hydra ``config`` with the ``task=validation`` override.

    ``weight`` is cleared to ``None``. The fixture keeps its original
    spelling (``validaion``) because tests request it by that name.
    """
    task_overrides = ["task=validation"]
    with initialize(config_path="../../yolo/config", version_base=None):
        validation_cfg: Config = compose(config_name="config", overrides=task_overrides)
        validation_cfg.weight = None
    return validation_cfg


@pytest.fixture
def cfg_inference() -> Config:
    """Compose the Hydra ``config`` with the ``task=inference`` override.

    ``weight`` is cleared to ``None`` before the config is returned.
    """
    task_overrides = ["task=inference"]
    with initialize(config_path="../../yolo/config", version_base=None):
        inference_cfg: Config = compose(config_name="config", overrides=task_overrides)
        inference_cfg.weight = None
    return inference_cfg


@pytest.fixture
def device() -> torch.device:
    """Return the CUDA device when one is available, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture
def model(cfg: Config, device) -> YOLO:
    """Build a YOLO from ``cfg.model`` (no weight file) and move it to ``device``."""
    return create_model(cfg.model, weight_path=None).to(device)


@pytest.fixture
def vec2box(cfg: Config, model: YOLO, device) -> Vec2Box:
    """Build a ``Vec2Box`` converter bound to the ``model`` fixture.

    The injected ``model`` (already moved to ``device`` by its fixture) is
    reused rather than building a second network. As a result, the converter
    and the solver under test are constructed from the same model object.
    """
    return Vec2Box(model, cfg.image_size, device)


@pytest.fixture
def progress_logger(cfg: Config):
    """Create a ``ProgressLogger`` for ``cfg`` whose experiment name is ``cfg.name``."""
    return ProgressLogger(cfg, exp_name=cfg.name)


# def test_model_trainer_initialization(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
#     trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
#     assert trainer.model == model
#     assert trainer.device == device
#     assert trainer.optimizer is not None
#     assert trainer.scheduler is not None
#     assert trainer.loss_fn is not None
#     assert trainer.progress == progress_logger


# def test_model_trainer_train_one_batch(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
#     trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
#     images = torch.rand(1, 3, 224, 224)
#     targets = torch.rand(1, 5)
#     loss_item = trainer.train_one_batch(images, targets)
#     assert isinstance(loss_item, dict)


def test_model_validator_initialization(cfg_validaion: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """``ModelValidator`` keeps the model, device and progress logger it is given."""
    task_cfg = cfg_validaion.task
    dataset_cfg = cfg_validaion.dataset
    validator = ModelValidator(
        task_cfg,
        dataset_cfg,
        model,
        vec2box,
        progress_logger,
        device,
    )

    assert validator.model == model
    assert validator.device == device
    assert validator.progress == progress_logger


def test_model_tester_initialization(cfg_inference: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
    """``ModelTester`` keeps the model, device and progress logger it is given."""
    tester = ModelTester(
        cfg_inference,
        model,
        vec2box,
        progress_logger,
        device,
    )

    assert tester.model == model
    assert tester.device == device
    assert tester.progress == progress_logger