henry000 commited on
Commit
25d3c1c
·
1 Parent(s): 542860e

✅ [Add] pytest for load model, config

Browse files
tests/model/test_yolo.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import torch
3
+ import sys
4
+
5
+ sys.path.append("./")
6
+ from utils.tools import load_model_cfg
7
+ from model.yolo import YOLO
8
+
9
+
10
+ def test_load_model_configuration():
11
+ config_name = "v7-base"
12
+ model_cfg = load_model_cfg(config_name)
13
+
14
+ assert "model" in model_cfg
15
+ assert isinstance(model_cfg, dict)
16
+
17
+
18
+ def test_build_model():
19
+ config_name = "v7-base"
20
+ model_cfg = load_model_cfg(config_name)
21
+ model = YOLO(model_cfg)
22
+ model.build_model(model_cfg["model"])
23
+ assert len(model.model) == 106
24
+
25
+
26
+ def test_yolo_forward_output_shape():
27
+ config_name = "v7-base"
28
+ model_cfg = load_model_cfg(config_name)
29
+ model = YOLO(model_cfg)
30
+
31
+ # 2 - batch size, 3 - number of channels, 640x640 - image dimensions
32
+ dummy_input = torch.rand(2, 3, 640, 640)
33
+
34
+ # Forward pass through the model
35
+ output = model(dummy_input)
36
+ output_shape = [x.shape for x in output]
37
+ assert output_shape == [
38
+ torch.Size([2, 3, 20, 20, 85]),
39
+ torch.Size([2, 3, 80, 80, 85]),
40
+ torch.Size([2, 3, 40, 40, 85]),
41
+ ]
tests/test_utils/test_tools.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ from yaml import YAMLError
4
+ from unittest.mock import mock_open, patch
5
+
6
+ import sys
7
+
8
+ sys.path.append("./")
9
+ from utils.tools import complete_path, load_model_cfg
10
+
11
+
12
+ # Test for complete_path function
13
+ def test_complete_path():
14
+ assert complete_path() == "./config/model/v7-base.yaml"
15
+
16
+ assert complete_path("test") == "./config/model/test.yaml"
17
+
18
+ assert complete_path("test.yaml") == "./config/model/test.yaml"
19
+
20
+
21
+ # Test for load_model_cfg function
22
+ def test_load_model_cfg_success():
23
+ test_yaml_content = """
24
+ nc: 80
25
+ model:
26
+ type: "yolo"
27
+ """
28
+ with patch("builtins.open", mock_open(read_data=test_yaml_content)):
29
+ with patch("os.path.exists", return_value=True):
30
+ result = load_model_cfg("v7-base.yaml")
31
+ assert result["nc"] == 80
32
+ assert result["model"]["type"] == "yolo"
33
+
34
+
35
+ def test_load_model_cfg_file_not_found():
36
+ with patch("os.path.exists", return_value=False):
37
+ with pytest.raises(FileNotFoundError):
38
+ load_model_cfg("missing-file.yaml")
39
+
40
+
41
+ def test_load_model_cfg_yaml_error():
42
+ # Simulating a YAML error
43
+ with patch("builtins.open", mock_open()) as mocked_file:
44
+ mocked_file.side_effect = YAMLError("error parsing YAML")
45
+ with pytest.raises(YAMLError):
46
+ load_model_cfg("corrupt-file.yaml")
47
+
48
+
49
+ def test_load_model_cfg_missing_keys():
50
+ test_yaml_content = """
51
+ nc: 80
52
+ """
53
+ with patch("builtins.open", mock_open(read_data=test_yaml_content)):
54
+ with patch("os.path.exists", return_value=True):
55
+ with pytest.raises(ValueError) as exc_info:
56
+ load_model_cfg("incomplete-model.yaml")
57
+ assert str(exc_info.value) == "Missing required key: 'model'"