import sys
from pathlib import Path

from torch.utils.data import DataLoader

# Put the directory three levels above this file (presumably the repository
# root) on sys.path so the project-local `yolo` imports below resolve.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))

from yolo.config.config import Config  # noqa: E402  (must follow sys.path setup)
from yolo.tools.data_loader import StreamDataLoader, create_dataloader  # noqa: E402


def test_create_dataloader_cache(train_cfg: Config):
    """Check that a freshly built cache and a reloaded cache yield matching batches.

    The cache file is deleted first, so the first ``create_dataloader`` call has
    to build it; the second call should then read the cache back. Both loaders
    run with shuffling disabled so their first batches are comparable.
    """
    # NOTE(review): this mutates the ``train_cfg`` fixture in place; if the
    # fixture is not function-scoped these settings leak into later tests.
    train_cfg.task.data.shuffle = False
    train_cfg.task.data.batch_size = 2

    # Relative to the current working directory, like the image paths asserted
    # in the tests below — pytest is assumed to run from the repository root.
    cache_file = Path("tests/data/train.cache")
    cache_file.unlink(missing_ok=True)

    make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
    # Without this check, a relocated or disabled cache would let both loaders
    # take the build-from-scratch path and the test would still pass.
    assert cache_file.exists(), f"expected the first loader to write {cache_file}"
    load_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)

    m_batch_size, m_images, _, m_reverse_tensors, m_image_paths = next(iter(make_cache_loader))
    l_batch_size, l_images, _, l_reverse_tensors, l_image_paths = next(iter(load_cache_loader))

    # Only shapes (not pixel values) are compared for images/reverse tensors;
    # the image paths pin that both loaders yielded the same samples.
    assert m_batch_size == l_batch_size
    assert m_images.shape == l_images.shape
    assert m_reverse_tensors.shape == l_reverse_tensors.shape
    assert m_image_paths == l_image_paths


def test_training_data_loader_correctness(train_dataloader: DataLoader):
    """Test that the training data loader produces correctly shaped data and metadata."""
    batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))

    # Expected values assume the fixture is configured with batch_size=2, a
    # 640x640 input and no shuffling — confirm against conftest.
    assert batch_size == 2
    assert images.shape == (2, 3, 640, 640)
    assert reverse_tensors.shape == (2, 5)

    expected_paths = [
        Path("tests/data/images/train/000000050725.jpg"),
        Path("tests/data/images/train/000000167848.jpg"),
    ]
    assert list(image_paths) == expected_paths


def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
    """Test that the validation data loader yields the expected batch, targets and paths."""
    batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))

    assert batch_size == 5
    assert images.shape == (5, 3, 640, 640)
    # 18 is presumably the largest per-image box count in this batch (targets
    # padded to a common length, 5 values per box) — confirm against the collate fn.
    assert targets.shape == (5, 18, 5)
    assert reverse_tensors.shape == (5, 5)

    expected_paths = [
        Path("tests/data/images/val/000000151480.jpg"),
        Path("tests/data/images/val/000000284106.jpg"),
        Path("tests/data/images/val/000000323571.jpg"),
        Path("tests/data/images/val/000000556498.jpg"),
        Path("tests/data/images/val/000000570456.jpg"),
    ]
    assert list(image_paths) == expected_paths


def test_file_stream_data_loader_frame(file_stream_data_loader: StreamDataLoader):
    """Test the frame output from the file stream data loader."""
    frame, rev_tensor, origin_frame = next(iter(file_stream_data_loader))

    assert frame.shape == (1, 3, 640, 640)
    assert rev_tensor.shape == (1, 5)
    # origin_frame is presumably a PIL image, whose .size is (width, height).
    assert origin_frame.size == (1024, 768)


def test_directory_stream_data_loader_frame(directory_stream_data_loader: StreamDataLoader):
    """Test the frame output from the directory stream data loader."""
    frame, rev_tensor, origin_frame = next(iter(directory_stream_data_loader))

    assert frame.shape == (1, 3, 640, 640)
    assert rev_tensor.shape == (1, 5)
    # Weak check: only verifies the original frame was not replaced by the
    # 640x640 model input; the exact source resolution is not pinned here.
    assert origin_frame.size != (640, 640)