"""Smoke tests for the dataloader factories in ``models.helpers.dataloaders``."""

import unittest
from unittest.mock import patch

from torch.utils.data import DataLoader

from models.helpers.dataloaders import train_dataloader, train_val_dataloader


class TestDataLoader(unittest.TestCase):
    """Check that the dataloader factories return usable ``DataLoader`` objects.

    These tests call the real factory functions (nothing is mocked), so they
    need whatever data those factories read to be available.
    """

    # Keyword arguments passed to both factories under test.
    # NOTE(review): ``cache``/``mem_cache`` presumably toggle dataset caching;
    # confirm against ``models.helpers.dataloaders``.
    LOADER_KWARGS = {
        "batch_size": 2,
        "num_workers": 2,
        "cache": False,
        "mem_cache": False,
    }

    # Expected ``len()`` of a single batch from the training loader.
    # NOTE(review): presumably the number of fields produced per batch by the
    # collate function -- confirm and update if the batch layout changes.
    EXPECTED_BATCH_LEN = 13

    def test_train_dataloader(self):
        """``train_dataloader`` returns a ``DataLoader`` whose first batch has 13 items."""
        train_loader = train_dataloader(**self.LOADER_KWARGS)

        self.assertIsInstance(train_loader, DataLoader)

        # Fetch only the first batch. Using ``next`` with a default (instead
        # of ``for ...: break``) lets the test fail loudly when the loader is
        # empty, rather than silently skipping the length assertion.
        first_batch = next(iter(train_loader), None)
        self.assertIsNotNone(
            first_batch, "train_dataloader yielded no batches"
        )
        self.assertEqual(len(first_batch), self.EXPECTED_BATCH_LEN)

    def test_train_val_dataloader(self):
        """``train_val_dataloader`` returns a (train, val) pair of ``DataLoader`` objects."""
        train_loader, val_loader = train_val_dataloader(**self.LOADER_KWARGS)

        self.assertIsInstance(train_loader, DataLoader)
        self.assertIsInstance(val_loader, DataLoader)


if __name__ == "__main__":
    unittest.main()