Spaces:
Running
Running
import unittest | |
from unittest.mock import patch | |
from torch.utils.data import DataLoader | |
from models.helpers.dataloaders import train_dataloader, train_val_dataloader | |
class TestDataLoader(unittest.TestCase): | |
def test_train_dataloader(self): | |
train_loader = train_dataloader( | |
batch_size=2, | |
num_workers=2, | |
cache=False, | |
mem_cache=False, | |
) | |
# Assertions | |
self.assertIsInstance(train_loader, DataLoader) | |
for batch in train_loader: | |
self.assertEqual(len(batch), 13) | |
break | |
def test_train_val_dataloader(self): | |
train_loader, val_loader = train_val_dataloader( | |
batch_size=2, | |
num_workers=2, | |
cache=False, | |
mem_cache=False, | |
) | |
# Assertions | |
self.assertIsInstance(train_loader, DataLoader) | |
self.assertIsInstance(val_loader, DataLoader) | |
if __name__ == "__main__": | |
unittest.main() | |