✅ [Pass] and add tests for yolo.py
Browse files- tests/model/test_yolo.py +0 -41
- tests/test_model/test_yolo.py +48 -0
- tests/test_utils/test_dataaugment.py +63 -0
tests/model/test_yolo.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
import pytest
|
2 |
-
import torch
|
3 |
-
import sys
|
4 |
-
|
5 |
-
sys.path.append("./")
|
6 |
-
from utils.tools import load_model_cfg
|
7 |
-
from model.yolo import YOLO
|
8 |
-
|
9 |
-
|
10 |
-
def test_load_model_configuration():
|
11 |
-
config_name = "v7-base"
|
12 |
-
model_cfg = load_model_cfg(config_name)
|
13 |
-
|
14 |
-
assert "model" in model_cfg
|
15 |
-
assert isinstance(model_cfg, dict)
|
16 |
-
|
17 |
-
|
18 |
-
def test_build_model():
|
19 |
-
config_name = "v7-base"
|
20 |
-
model_cfg = load_model_cfg(config_name)
|
21 |
-
model = YOLO(model_cfg)
|
22 |
-
model.build_model(model_cfg["model"])
|
23 |
-
assert len(model.model) == 106
|
24 |
-
|
25 |
-
|
26 |
-
def test_yolo_forward_output_shape():
|
27 |
-
config_name = "v7-base"
|
28 |
-
model_cfg = load_model_cfg(config_name)
|
29 |
-
model = YOLO(model_cfg)
|
30 |
-
|
31 |
-
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
32 |
-
dummy_input = torch.rand(2, 3, 640, 640)
|
33 |
-
|
34 |
-
# Forward pass through the model
|
35 |
-
output = model(dummy_input)
|
36 |
-
output_shape = [x.shape for x in output[-1]]
|
37 |
-
assert output_shape == [
|
38 |
-
torch.Size([2, 3, 20, 20, 85]),
|
39 |
-
torch.Size([2, 3, 80, 80, 85]),
|
40 |
-
torch.Size([2, 3, 40, 40, 85]),
|
41 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_model/test_yolo.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
import torch
|
3 |
+
from hydra import initialize, compose
|
4 |
+
from hydra.core.global_hydra import GlobalHydra
|
5 |
+
from omegaconf import DictConfig, OmegaConf
|
6 |
+
|
7 |
+
import sys
|
8 |
+
|
9 |
+
sys.path.append("./")
|
10 |
+
from model.yolo import YOLO, get_model
|
11 |
+
|
12 |
+
config_path = "../../config/model"
|
13 |
+
config_name = "v7-base"
|
14 |
+
|
15 |
+
|
16 |
+
def test_build_model():
|
17 |
+
|
18 |
+
with initialize(config_path=config_path, version_base=None):
|
19 |
+
model_cfg = compose(config_name=config_name)
|
20 |
+
OmegaConf.set_struct(model_cfg, False)
|
21 |
+
model = YOLO(model_cfg)
|
22 |
+
model.build_model(model_cfg.model)
|
23 |
+
assert len(model.model) == 106
|
24 |
+
|
25 |
+
|
26 |
+
def test_get_model():
|
27 |
+
with initialize(config_path=config_path, version_base=None):
|
28 |
+
model_cfg = compose(config_name=config_name)
|
29 |
+
model = get_model(model_cfg)
|
30 |
+
assert isinstance(model, YOLO)
|
31 |
+
|
32 |
+
|
33 |
+
def test_yolo_forward_output_shape():
|
34 |
+
with initialize(config_path=config_path, version_base=None):
|
35 |
+
model_cfg = compose(config_name=config_name)
|
36 |
+
|
37 |
+
model = get_model(model_cfg)
|
38 |
+
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
39 |
+
dummy_input = torch.rand(2, 3, 640, 640)
|
40 |
+
|
41 |
+
# Forward pass through the model
|
42 |
+
output = model(dummy_input)
|
43 |
+
output_shape = [x.shape for x in output[-1]]
|
44 |
+
assert output_shape == [
|
45 |
+
torch.Size([2, 3, 20, 20, 85]),
|
46 |
+
torch.Size([2, 3, 80, 80, 85]),
|
47 |
+
torch.Size([2, 3, 40, 40, 85]),
|
48 |
+
]
|
tests/test_utils/test_dataaugment.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision.transforms import functional as TF
|
5 |
+
import sys
|
6 |
+
|
7 |
+
sys.path.append("./")
|
8 |
+
from utils.dataargument import RandomHorizontalFlip, Compose, Mosaic
|
9 |
+
|
10 |
+
|
11 |
+
def test_random_horizontal_flip():
|
12 |
+
# Create a mock image and bounding boxes
|
13 |
+
img = Image.new("RGB", (100, 100), color="red")
|
14 |
+
boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]]) # class, xmin, ymin, xmax, ymax
|
15 |
+
|
16 |
+
flip_transform = RandomHorizontalFlip(prob=1) # Set probability to 1 to ensure flip
|
17 |
+
flipped_img, flipped_boxes = flip_transform(img, boxes)
|
18 |
+
|
19 |
+
# Assert image is flipped by comparing it to a manually flipped image
|
20 |
+
assert TF.hflip(img) == flipped_img
|
21 |
+
|
22 |
+
# Assert bounding boxes are flipped correctly
|
23 |
+
expected_boxes = torch.tensor([[1, 0.1, 0.1, 0.9, 0.9]])
|
24 |
+
expected_boxes[:, [1, 3]] = 1 - expected_boxes[:, [3, 1]]
|
25 |
+
assert torch.allclose(flipped_boxes, expected_boxes), "Bounding boxes were not flipped correctly"
|
26 |
+
|
27 |
+
|
28 |
+
def test_compose():
|
29 |
+
# Define two mock transforms that simply return the inputs
|
30 |
+
def mock_transform(image, boxes):
|
31 |
+
return image, boxes
|
32 |
+
|
33 |
+
compose = Compose([mock_transform, mock_transform])
|
34 |
+
img = Image.new("RGB", (10, 10), color="blue")
|
35 |
+
boxes = torch.tensor([[0, 0.2, 0.2, 0.8, 0.8]])
|
36 |
+
|
37 |
+
transformed_img, transformed_boxes = compose(img, boxes)
|
38 |
+
|
39 |
+
assert transformed_img == img, "Image should not be altered"
|
40 |
+
assert torch.equal(transformed_boxes, boxes), "Boxes should not be altered"
|
41 |
+
|
42 |
+
|
43 |
+
def test_mosaic():
|
44 |
+
img = Image.new("RGB", (100, 100), color="green")
|
45 |
+
boxes = torch.tensor([[0, 0.25, 0.25, 0.75, 0.75]])
|
46 |
+
|
47 |
+
# Mock parent with image_size and get_more_data method
|
48 |
+
class MockParent:
|
49 |
+
image_size = 100
|
50 |
+
|
51 |
+
def get_more_data(self, num_images):
|
52 |
+
return [(img, boxes) for _ in range(num_images)]
|
53 |
+
|
54 |
+
mosaic = Mosaic(prob=1) # Ensure mosaic is applied
|
55 |
+
mosaic.set_parent(MockParent())
|
56 |
+
|
57 |
+
mosaic_img, mosaic_boxes = mosaic(img, boxes)
|
58 |
+
|
59 |
+
# Checks here would depend on the exact expected behavior of the mosaic function,
|
60 |
+
# such as dimensions and content of the output image and boxes.
|
61 |
+
|
62 |
+
assert mosaic_img.size == (200, 200), "Mosaic image size should be doubled"
|
63 |
+
assert len(mosaic_boxes) > 0, "Should have some bounding boxes"
|