File size: 1,521 Bytes
31cab2b 80efe00 31cab2b a0a9c3e 31cab2b 80efe00 a0a9c3e 80efe00 a0a9c3e 80efe00 7b9ca3a a0a9c3e 7b9ca3a a0a9c3e 7b9ca3a a0a9c3e 7b9ca3a a0a9c3e 7b9ca3a a0a9c3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import sys
from pathlib import Path
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
# Repository root: two directories above the one containing this test file.
project_root = Path(__file__).resolve().parent.parent.parent
# Put the repository root at the FRONT of sys.path so `import yolo` resolves to
# this checkout rather than any separately installed `yolo` package, and skip
# the insert when it is already present (e.g. another test module added it),
# so repeated imports do not pile up duplicate entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
from yolo.model.yolo import YOLO, get_model
# Hydra config location (relative to this test file, as `initialize` expects)
# and the name of the primary config to compose in every test below.
config_path, config_name = "../../yolo/config", "config"
def test_build_model():
    """Construct YOLO directly from the composed `model` config and check its size.

    The expected value of 38 is tied to the default model config; it must be
    updated if that config gains or loses entries.
    """
    with initialize(config_path=config_path, version_base=None):
        composed = compose(config_name=config_name)
        # Turn off OmegaConf struct mode so keys not declared in the config can be
        # added; presumably YOLO(...) writes into cfg.model — confirm.
        OmegaConf.set_struct(composed.model, False)
        # NOTE(review): 80 is passed positionally; presumably the class count.
        yolo_model = YOLO(composed.model, 80)
        expected_len = 38
        assert len(yolo_model.model) == expected_len
def test_get_model():
    """`get_model` should return a YOLO instance when given the full composed config."""
    with initialize(config_path=config_path, version_base=None):
        built = get_model(compose(config_name=config_name))
        assert isinstance(built, YOLO)
def test_yolo_forward_output_shape():
    """Run a random batch through the default model and check the last output's shapes.

    The last element of the model's output is iterated, and each item's `.shape`
    is compared against six expected sizes: the grid sizes 80, 40 and 20 (for a
    640x640 input), listed twice in that order.
    """
    with initialize(config_path=config_path, version_base=None):
        cfg = compose(config_name=config_name)
        model = get_model(cfg)

        batch_size, channels, image_size = 2, 3, 640
        dummy_input = torch.rand(batch_size, channels, image_size, image_size)

        # Only output shapes are checked, so don't build an autograd graph for a
        # full-resolution batch; this saves memory and time without affecting shapes.
        with torch.no_grad():
            output = model(dummy_input)

        output_shape = [x.shape for x in output[-1]]
        # NOTE(review): the 144 channels and the repeated set of three grid sizes
        # (presumably two detection heads) come from the default model config;
        # update these if that config changes.
        grid_sizes = (80, 40, 20)
        expected = [torch.Size([batch_size, 144, g, g]) for g in grid_sizes] * 2
        assert output_shape == expected
|