File size: 1,478 Bytes
31cab2b
80efe00
31cab2b
a0a9c3e
31cab2b
80efe00
a0a9c3e
80efe00
 
a0a9c3e
80efe00
 
 
a0a9c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import sys
from pathlib import Path

import torch
from hydra import compose, initialize
from omegaconf import OmegaConf

# Directory three levels above this test file (resolve() first so symlinks
# and relative invocations yield an absolute path).
project_root = Path(__file__).resolve().parent.parent.parent
# Make the top-level `yolo` package importable without installing it;
# presumably project_root is the repository root that contains `yolo/`.
# This must run before the project import below.
sys.path.append(str(project_root))

from yolo.model.yolo import YOLO, get_model  # noqa: E402  (needs sys.path tweak above)

# Hydra config directory, given as a path relative to this file (Hydra's
# `initialize` resolves `config_path` relative to the calling module).
config_path = "../../yolo/config/model"
# Name of the model config (without the .yaml suffix) composed by every test.
config_name = "v7-base"


def test_build_model():
    """Construct the v7-base network explicitly and verify its layer count.

    A ``YOLO`` instance is created from the composed config, then
    ``build_model`` is invoked with the ``model`` sub-config. The resulting
    ``.model`` container is expected to hold exactly 106 entries.
    """
    with initialize(config_path=config_path, version_base=None):
        cfg = compose(config_name=config_name)
        # Turn off OmegaConf struct mode so keys not declared in the YAML may
        # be added later (presumably needed by YOLO/build_model — confirm).
        OmegaConf.set_struct(cfg, False)

        yolo = YOLO(cfg)
        yolo.build_model(cfg.model)

        expected_layer_count = 106
        assert len(yolo.model) == expected_layer_count


def test_get_model():
    """The ``get_model`` factory should hand back a ``YOLO`` instance."""
    with initialize(config_path=config_path, version_base=None):
        built = get_model(compose(config_name=config_name))
        assert isinstance(built, YOLO)


def test_yolo_forward_output_shape():
    """Check the per-level output shapes of a v7-base forward pass.

    A random batch is pushed through the model returned by ``get_model``.
    The shapes of the tensors in the last element of the model output are
    compared against a fixed expected list.
    """
    with initialize(config_path=config_path, version_base=None):
        model_cfg = compose(config_name=config_name)

        model = get_model(model_cfg)
        # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
        dummy_input = torch.rand(2, 3, 640, 640)

        # Forward pass through the model. Only shapes are inspected, so
        # autograd is disabled: otherwise every intermediate activation of
        # this 640x640 batch would be retained for a backward pass that never
        # happens. The model's train/eval mode is deliberately left as
        # get_model returned it, since switching modes could change the
        # structure of the output asserted below.
        with torch.no_grad():
            output = model(dummy_input)
        output_shape = [x.shape for x in output[-1]]
        # NOTE(review): the spatial sizes 20/80/40 presumably correspond to the
        # detection levels in the order the v7-base config emits them, and the
        # trailing 85 presumably is 4 box + 1 objectness + 80 classes — confirm.
        assert output_shape == [
            torch.Size([2, 3, 20, 20, 85]),
            torch.Size([2, 3, 80, 80, 85]),
            torch.Size([2, 3, 40, 40, 85]),
        ]