Sebastien
first commit
4484b8a
raw
history blame contribute delete
967 Bytes
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import DataLoader
import numpy as np
from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets
class CustomDataLoader(DataLoader):
    """Draw batches from several sized sources, choosing one source per step.

    Each entry of ``data_iters`` is wrapped in its own ``DataLoader`` and a
    single iterator over that loader is created up front.  On every step the
    current ``len()`` of each source decides which loader supplies the batch.

    NOTE: ``DataLoader.__init__`` is intentionally not invoked (there is no
    single dataset to hand to it), so base-class attributes that it would
    normally set are not available on instances of this class.
    """

    def __init__(self, data_iters, batch_size):
        """Wrap every source in ``data_iters`` in a ``DataLoader``.

        Args:
            data_iters: Sequence of datasets.  Each must support ``len()``
                and be accepted by ``torch.utils.data.DataLoader``.
                (Presumably ``DataIterBuffer`` objects -- confirm with callers.)
            batch_size: Batch size passed to every underlying ``DataLoader``
                and used as the "ready" threshold when selecting a source.
        """
        self.data_iters = data_iters
        self.batch_size = batch_size
        # The iterators are created once here, so every call to ``__iter__``
        # continues from wherever the previous consumer stopped.
        self.data_loaders = [
            iter(DataLoader(data_iter, batch_size=batch_size))
            for data_iter in data_iters
        ]

    def __iter__(self):
        """Yield ``[source_index] + batch`` until the selected loader runs dry.

        Selection rule, re-evaluated before every batch:
          * if at least one source holds ``batch_size`` or more items, take
            the one with the highest index among those;
          * otherwise take the source that currently holds the most items
            (the first one on ties).

        Assumes the collated batch is a ``list`` so it can be concatenated
        with the leading index -- TODO confirm for the datasets in use.
        """
        while True:
            buffer_sizes = np.array([len(buffer) for buffer in self.data_iters])
            ready = buffer_sizes >= self.batch_size
            if ready.any():
                # argmax over the reversed mask finds the last ready source;
                # map that position back to an index in the original order.
                idx_yield = len(buffer_sizes) - 1 - np.argmax(ready[::-1])
            else:
                idx_yield = np.argmax(buffer_sizes)
            try:
                batch = next(self.data_loaders[idx_yield])
            except StopIteration:
                # A StopIteration escaping a generator body is turned into
                # RuntimeError (PEP 479); end the iteration cleanly instead.
                return
            yield [idx_yield] + batch

    @property
    def num_workers(self):
        """Always 0; provided because the base initializer never sets it."""
        return 0