|
|
|
|
|
from torch.utils.data.dataset import IterableDataset |
|
from torch.utils.data import DataLoader |
|
import numpy as np |
|
|
|
from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets |
|
|
|
|
|
|
|
class CustomDataLoader(DataLoader):
    """Serve batches from several sized datasets, one source per batch.

    One inner ``DataLoader`` iterator is created per entry of ``data_iters``.
    On every step the current ``len()`` of each entry decides which inner
    loader supplies the next batch (see ``__iter__``).

    Args:
        data_iters: Sequence of datasets. Each must support ``len()`` and be
            accepted by ``torch.utils.data.DataLoader``.
        batch_size: Batch size passed to every inner ``DataLoader``; also the
            threshold used to decide whether a source is "full enough".
    """

    def __init__(self, data_iters, batch_size):
        # NOTE(review): DataLoader.__init__ is deliberately not called here.
        # It appears to assign ``self.num_workers``, which would clash with
        # the read-only property defined below -- confirm before changing.
        self.data_iters = data_iters
        self.batch_size = batch_size
        # The inner iterators are created once, so iteration state is shared
        # across every call to ``__iter__``.
        self.data_loaders = [
            iter(DataLoader(data_iter, batch_size=batch_size))
            for data_iter in data_iters
        ]

    def __iter__(self):
        """Yield ``[source_index] + batch`` until the chosen source runs dry.

        Selection rule, evaluated before every batch:

        * if at least one source holds ``batch_size`` or more items, the one
          with the highest index among those is used;
        * otherwise the source with the most items is used.

        Iteration ends cleanly when the selected inner loader is exhausted.
        """
        # Without any source, np.argmax below would raise on an empty array.
        if not self.data_loaders:
            return

        while True:
            buffer_sizes = np.array([len(buffer) for buffer in self.data_iters])
            ready = buffer_sizes >= self.batch_size

            if ready.any():
                # argmax on the reversed mask finds the *last* ready source.
                idx_yield = len(buffer_sizes) - 1 - np.argmax(ready[::-1])
            else:
                idx_yield = np.argmax(buffer_sizes)

            try:
                batch = next(self.data_loaders[idx_yield])
            except StopIteration:
                # A StopIteration escaping a generator body is converted to
                # RuntimeError (PEP 479); finish the generator explicitly.
                return

            yield [idx_yield] + batch

    @property
    def num_workers(self):
        """Always 0: this wrapper reports no worker processes of its own."""
        return 0
|
|