✅ [Pass] the test for newdataloader
Browse files
tests/test_tools/test_data_loader.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
|
|
|
|
4 |
project_root = Path(__file__).resolve().parent.parent.parent
|
5 |
sys.path.append(str(project_root))
|
6 |
|
7 |
from yolo.config.config import Config
|
8 |
-
from yolo.tools.data_loader import StreamDataLoader,
|
9 |
|
10 |
|
11 |
def test_create_dataloader_cache(train_cfg: Config):
|
@@ -25,7 +27,7 @@ def test_create_dataloader_cache(train_cfg: Config):
|
|
25 |
assert m_image_paths == l_image_paths
|
26 |
|
27 |
|
28 |
-
def test_training_data_loader_correctness(train_dataloader:
|
29 |
"""Test that the training data loader produces correctly shaped data and metadata."""
|
30 |
batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
|
31 |
assert batch_size == 2
|
@@ -38,7 +40,7 @@ def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
|
|
38 |
assert list(image_paths) == list(expected_paths)
|
39 |
|
40 |
|
41 |
-
def test_validation_data_loader_correctness(validation_dataloader:
|
42 |
batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
|
43 |
assert batch_size == 4
|
44 |
assert images.shape == (4, 3, 640, 640)
|
|
|
1 |
import sys
|
2 |
from pathlib import Path
|
3 |
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
project_root = Path(__file__).resolve().parent.parent.parent
|
7 |
sys.path.append(str(project_root))
|
8 |
|
9 |
from yolo.config.config import Config
|
10 |
+
from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
11 |
|
12 |
|
13 |
def test_create_dataloader_cache(train_cfg: Config):
|
|
|
27 |
assert m_image_paths == l_image_paths
|
28 |
|
29 |
|
30 |
+
def test_training_data_loader_correctness(train_dataloader: DataLoader):
|
31 |
"""Test that the training data loader produces correctly shaped data and metadata."""
|
32 |
batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
|
33 |
assert batch_size == 2
|
|
|
40 |
assert list(image_paths) == list(expected_paths)
|
41 |
|
42 |
|
43 |
+
def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
|
44 |
batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
|
45 |
assert batch_size == 4
|
46 |
assert images.shape == (4, 3, 640, 640)
|