π [Merge] branch 'main' into MODEL
Browse files- .pre-commit-config.yaml +1 -1
- tests/test_model/test_yolo.py +17 -15
.pre-commit-config.yaml
CHANGED
@@ -11,4 +11,4 @@ repos:
|
|
11 |
rev: v5.10.1 # Use the appropriate version or "stable" for the latest stable release
|
12 |
hooks:
|
13 |
- id: isort
|
14 |
-
args: ["--profile", "black"
|
|
|
11 |
rev: v5.10.1 # Use the appropriate version or "stable" for the latest stable release
|
12 |
hooks:
|
13 |
- id: isort
|
14 |
+
args: ["--profile", "black"]
|
tests/test_model/test_yolo.py
CHANGED
@@ -10,31 +10,30 @@ sys.path.append(str(project_root))
|
|
10 |
|
11 |
from yolo.model.yolo import YOLO, get_model
|
12 |
|
13 |
-
config_path = "../../yolo/config
|
14 |
-
config_name = "
|
15 |
|
16 |
|
17 |
def test_build_model():
|
18 |
with initialize(config_path=config_path, version_base=None):
|
19 |
-
|
20 |
-
|
21 |
-
model
|
22 |
-
model
|
23 |
-
assert len(model.model) ==
|
24 |
|
25 |
|
26 |
def test_get_model():
|
27 |
with initialize(config_path=config_path, version_base=None):
|
28 |
-
|
29 |
-
model = get_model(
|
30 |
assert isinstance(model, YOLO)
|
31 |
|
32 |
|
33 |
def test_yolo_forward_output_shape():
|
34 |
with initialize(config_path=config_path, version_base=None):
|
35 |
-
|
36 |
-
|
37 |
-
model = get_model(model_cfg)
|
38 |
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
39 |
dummy_input = torch.rand(2, 3, 640, 640)
|
40 |
|
@@ -42,7 +41,10 @@ def test_yolo_forward_output_shape():
|
|
42 |
output = model(dummy_input)
|
43 |
output_shape = [x.shape for x in output[-1]]
|
44 |
assert output_shape == [
|
45 |
-
torch.Size([2,
|
46 |
-
torch.Size([2,
|
47 |
-
torch.Size([2,
|
|
|
|
|
|
|
48 |
]
|
|
|
10 |
|
11 |
from yolo.model.yolo import YOLO, get_model
|
12 |
|
13 |
+
config_path = "../../yolo/config"
|
14 |
+
config_name = "config"
|
15 |
|
16 |
|
17 |
def test_build_model():
|
18 |
with initialize(config_path=config_path, version_base=None):
|
19 |
+
cfg = compose(config_name=config_name)
|
20 |
+
|
21 |
+
OmegaConf.set_struct(cfg.model, False)
|
22 |
+
model = YOLO(cfg.model, 80)
|
23 |
+
assert len(model.model) == 38
|
24 |
|
25 |
|
26 |
def test_get_model():
|
27 |
with initialize(config_path=config_path, version_base=None):
|
28 |
+
cfg = compose(config_name=config_name)
|
29 |
+
model = get_model(cfg)
|
30 |
assert isinstance(model, YOLO)
|
31 |
|
32 |
|
33 |
def test_yolo_forward_output_shape():
|
34 |
with initialize(config_path=config_path, version_base=None):
|
35 |
+
cfg = compose(config_name=config_name)
|
36 |
+
model = get_model(cfg)
|
|
|
37 |
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
38 |
dummy_input = torch.rand(2, 3, 640, 640)
|
39 |
|
|
|
41 |
output = model(dummy_input)
|
42 |
output_shape = [x.shape for x in output[-1]]
|
43 |
assert output_shape == [
|
44 |
+
torch.Size([2, 144, 80, 80]),
|
45 |
+
torch.Size([2, 144, 40, 40]),
|
46 |
+
torch.Size([2, 144, 20, 20]),
|
47 |
+
torch.Size([2, 144, 80, 80]),
|
48 |
+
torch.Size([2, 144, 40, 40]),
|
49 |
+
torch.Size([2, 144, 20, 20]),
|
50 |
]
|