File size: 1,521 Bytes
31cab2b
80efe00
31cab2b
a0a9c3e
31cab2b
80efe00
a0a9c3e
80efe00
 
a0a9c3e
80efe00
 
7b9ca3a
 
a0a9c3e
 
 
 
7b9ca3a
 
 
 
 
a0a9c3e
 
 
 
7b9ca3a
 
a0a9c3e
 
 
 
 
7b9ca3a
 
a0a9c3e
 
 
 
 
 
 
7b9ca3a
 
 
 
 
 
a0a9c3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import sys
from pathlib import Path

import torch
from hydra import compose, initialize
from omegaconf import OmegaConf

# Three `.parent` hops from this file's resolved path give the repository
# root; appending it to sys.path lets the `yolo` import below resolve
# without the package having to be installed.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(f"{project_root}")

from yolo.model.yolo import YOLO, get_model

# Hydra config directory handed to `initialize` and the primary config name
# handed to `compose` by every test in this module.
config_path, config_name = "../../yolo/config", "config"


def test_build_model():
    """Construct YOLO directly from the composed model sub-config.

    The model is built with 80 classes and must expose exactly 38 entries in
    its ``model`` attribute.
    """
    num_classes = 80
    expected_layer_count = 38

    with initialize(config_path=config_path, version_base=None):
        cfg = compose(config_name=config_name)

        # Struct mode is disabled on the model sub-config before YOLO sees it;
        # NOTE(review): presumably because YOLO writes new keys into it — confirm.
        OmegaConf.set_struct(cfg.model, False)
        built = YOLO(cfg.model, num_classes)
        assert len(built.model) == expected_layer_count


def test_get_model():
    """`get_model` must return a YOLO instance for the composed `config`."""
    with initialize(config_path=config_path, version_base=None):
        built = get_model(compose(config_name=config_name))
        assert isinstance(built, YOLO)


def test_yolo_forward_output_shape():
    """Check the tensor shapes in the last entry of the model's forward output.

    A random batch of two 3-channel 640x640 images is passed through the model
    built by `get_model`. The final element of the output must contain six
    tensors, each with 144 channels. Their spatial sizes are 80x80, 40x40 and
    20x20, and that trio appears twice.
    """
    with initialize(config_path=config_path, version_base=None):
        cfg = compose(config_name=config_name)
        model = get_model(cfg)
        # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
        dummy_input = torch.rand(2, 3, 640, 640)

        # Only output shapes are inspected, so skip autograd bookkeeping:
        # with grad enabled every intermediate activation of this full-size
        # forward pass would be retained for a backward pass that never runs.
        with torch.no_grad():
            output = model(dummy_input)

        output_shape = [x.shape for x in output[-1]]
        # NOTE(review): 144 channels presumably = 80 classes + 4 * 16
        # box-regression bins, and the repeated 80/40/20 trio presumably comes
        # from two detection heads (e.g. auxiliary + main) — confirm against
        # the model config.
        assert output_shape == [
            torch.Size([2, 144, 80, 80]),
            torch.Size([2, 144, 40, 40]),
            torch.Size([2, 144, 20, 20]),
            torch.Size([2, 144, 80, 80]),
            torch.Size([2, 144, 40, 40]),
            torch.Size([2, 144, 20, 20]),
        ]