henry000 commited on
Commit
2784407
·
1 Parent(s): 930952c

✅ [Add] test for helper and basic module in model

Browse files
tests/test_model/test_module.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ project_root = Path(__file__).resolve().parent.parent.parent
8
+ sys.path.append(str(project_root))
9
+ from yolo.model.module import SPPELAN, ADown, CBLinear, Conv, Pool
10
+
11
+ STRIDE = 2
12
+ KERNEL_SIZE = 3
13
+ IN_CHANNELS = 64
14
+ OUT_CHANNELS = 128
15
+ NECK_CHANNELS = 64
16
+
17
+
18
+ def test_conv():
19
+ conv = Conv(IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE)
20
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
21
+ out = conv(x)
22
+ assert out.shape == (1, OUT_CHANNELS, 64, 64)
23
+
24
+
25
+ def test_pool_max():
26
+ pool = Pool("max", 2, stride=2)
27
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
28
+ out = pool(x)
29
+ assert out.shape == (1, IN_CHANNELS, 32, 32)
30
+
31
+
32
+ def test_pool_avg():
33
+ pool = Pool("avg", 2, stride=2)
34
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
35
+ out = pool(x)
36
+ assert out.shape == (1, IN_CHANNELS, 32, 32)
37
+
38
+
39
+ def test_adown():
40
+ adown = ADown(IN_CHANNELS, OUT_CHANNELS)
41
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
42
+ out = adown(x)
43
+ assert out.shape == (1, OUT_CHANNELS, 32, 32)
44
+
45
+
46
+ def test_adown():
47
+ adown = ADown(IN_CHANNELS, OUT_CHANNELS)
48
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
49
+ out = adown(x)
50
+ assert out.shape == (1, OUT_CHANNELS, 32, 32)
51
+
52
+
53
+ def test_cblinear():
54
+ cblinear = CBLinear(IN_CHANNELS, [5, 5])
55
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
56
+ outs = cblinear(x)
57
+ assert len(outs) == 2
58
+ assert outs[0].shape == (1, 5, 64, 64)
59
+ assert outs[1].shape == (1, 5, 64, 64)
60
+
61
+
62
+ def test_sppelan():
63
+ sppelan = SPPELAN(IN_CHANNELS, OUT_CHANNELS, NECK_CHANNELS)
64
+ x = torch.randn(1, IN_CHANNELS, 64, 64)
65
+ out = sppelan(x)
66
+ assert out.shape == (1, OUT_CHANNELS, 64, 64)
tests/test_tools/test_module_helper.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import pytest
5
+ import torch
6
+ from torch import nn
7
+
8
+ project_root = Path(__file__).resolve().parent.parent.parent
9
+ sys.path.append(str(project_root))
10
+ from yolo.tools.module_helper import auto_pad, get_activation
11
+
12
+
13
+ @pytest.mark.parametrize(
14
+ "kernel_size, dilation, expected",
15
+ [
16
+ (3, 1, (1, 1)),
17
+ ((3, 3), (1, 1), (1, 1)),
18
+ (3, (2, 2), (2, 2)),
19
+ ((5, 5), 1, (2, 2)),
20
+ ((3, 5), (2, 1), (2, 2)),
21
+ ],
22
+ )
23
+ def test_auto_pad(kernel_size, dilation, expected):
24
+ assert auto_pad(kernel_size, dilation) == expected, "auto_pad does not calculate padding correctly"
25
+
26
+
27
+ @pytest.mark.parametrize(
28
+ "activation_name, expected_type",
29
+ [("ReLU", nn.ReLU), ("leakyrelu", nn.LeakyReLU), ("none", nn.Identity), (None, nn.Identity), (False, nn.Identity)],
30
+ )
31
+ def test_get_activation(activation_name, expected_type):
32
+ result = get_activation(activation_name)
33
+ assert isinstance(result, expected_type), f"get_activation does not return correct type for {activation_name}"
34
+
35
+
36
+ def test_get_activation_invalid():
37
+ with pytest.raises(ValueError):
38
+ get_activation("unsupported_activation")