henry000 commited on
Commit
a757657
·
1 Parent(s): ac8e6e6

✅ [Pass] the test for newdataloader

Browse files
tests/test_tools/test_data_loader.py CHANGED
@@ -1,11 +1,13 @@
1
  import sys
2
  from pathlib import Path
3
 
 
 
4
  project_root = Path(__file__).resolve().parent.parent.parent
5
  sys.path.append(str(project_root))
6
 
7
  from yolo.config.config import Config
8
- from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader, create_dataloader
9
 
10
 
11
  def test_create_dataloader_cache(train_cfg: Config):
@@ -25,7 +27,7 @@ def test_create_dataloader_cache(train_cfg: Config):
25
  assert m_image_paths == l_image_paths
26
 
27
 
28
- def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
29
  """Test that the training data loader produces correctly shaped data and metadata."""
30
  batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
31
  assert batch_size == 2
@@ -38,7 +40,7 @@ def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
38
  assert list(image_paths) == list(expected_paths)
39
 
40
 
41
- def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
42
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
43
  assert batch_size == 4
44
  assert images.shape == (4, 3, 640, 640)
 
1
  import sys
2
  from pathlib import Path
3
 
4
+ from torch.utils.data import DataLoader
5
+
6
  project_root = Path(__file__).resolve().parent.parent.parent
7
  sys.path.append(str(project_root))
8
 
9
  from yolo.config.config import Config
10
+ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
11
 
12
 
13
  def test_create_dataloader_cache(train_cfg: Config):
 
27
  assert m_image_paths == l_image_paths
28
 
29
 
30
+ def test_training_data_loader_correctness(train_dataloader: DataLoader):
31
  """Test that the training data loader produces correctly shaped data and metadata."""
32
  batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
33
  assert batch_size == 2
 
40
  assert list(image_paths) == list(expected_paths)
41
 
42
 
43
+ def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
44
  batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
45
  assert batch_size == 4
46
  assert images.shape == (4, 3, 640, 640)