from torch.utils.data.dataset import IterableDataset
from torch.utils.data import DataLoader
import numpy as np
from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets


class CustomDataLoader(DataLoader):
    """Interleave batches from several sources, preferring ones that can fill a batch.

    Each element of ``data_iters`` is wrapped in its own ``DataLoader`` using the
    shared ``batch_size``. On every step one source is chosen:

    * if any source reports ``len(source) >= batch_size``, the highest-index such
      source is used;
    * otherwise the source with the largest ``len`` is used (lowest index on ties).

    Every yielded item is ``[source_index] + batch``, where ``source_index`` is a
    Python ``int`` and ``batch`` is the list produced by that source's ``DataLoader``.

    NOTE(review): ``DataLoader.__init__`` is not called, so the usual DataLoader
    attributes (``dataset``, ``sampler``, ...) are never set and inherited methods
    that rely on them (e.g. ``len()``) should not be expected to work.
    """

    def __init__(self, data_iters, batch_size):
        """Build one persistent batch iterator per source.

        Args:
            data_iters: sources that support both ``len()`` and use as a
                ``DataLoader`` dataset.
            batch_size: batch size for every per-source ``DataLoader``; also the
                threshold used when choosing which source to draw from.
        """
        # Materialize once: the sources are walked here and again on every step
        # of __iter__, so a one-shot iterator would otherwise be empty later.
        self.data_iters = list(data_iters)
        self.batch_size = batch_size
        # The iterators are created once, so repeated iteration over this loader
        # continues where the previous one stopped rather than restarting.
        self.data_loaders = [
            iter(DataLoader(data_iter, batch_size=batch_size))
            for data_iter in self.data_iters
        ]
        # DataLoader.__init__ is not called; it would assign ``self.num_workers``,
        # which is a read-only property on this class.

    def __iter__(self):
        """Yield ``[source_index] + batch`` until some chosen source is exhausted.

        Raises:
            ValueError: if the loader was constructed with no sources.
        """
        if not self.data_loaders:
            raise ValueError("CustomDataLoader requires at least one data iterator")
        while True:
            # NOTE(review): presumably len() reports how many samples each source
            # currently has buffered — confirm against the data_iters type.
            sizes = np.array([len(source) for source in self.data_iters])
            ready = np.flatnonzero(sizes >= self.batch_size)
            if ready.size:
                # Highest-index source able to fill a whole batch.
                idx = int(ready[-1])
            else:
                # No source can fill a batch: draw from the fullest one.
                idx = int(np.argmax(sizes))
            try:
                batch = next(self.data_loaders[idx])
            except StopIteration:
                # A StopIteration escaping a generator becomes a RuntimeError
                # (PEP 479); end the iteration cleanly instead.
                return
            if isinstance(batch, tuple):
                # Tuple/namedtuple batches cannot be concatenated onto a list.
                batch = list(batch)
            yield [idx] + batch

    @property
    def num_workers(self):
        """Report zero workers.

        The per-source DataLoaders are built with the default ``num_workers=0``,
        so loading happens in the main process.
        """
        return 0