File size: 967 Bytes
4484b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


from torch.utils.data.dataset import IterableDataset
from torch.utils.data import  DataLoader
import numpy as np

from sudoku.loader import DataIterBuffer, train_dataset, test_dataset, data_loader, get_datasets



class CustomDataLoader(DataLoader):
    """Interleave batches drawn from several sized datasets.

    One inner ``DataLoader`` iterator is created per entry of ``data_iters``.
    On every step the current ``len()`` of each entry is inspected to decide
    which inner iterator supplies the next batch:

    * if at least one entry holds ``batch_size`` or more items, the one with
      the highest index among those is used;
    * otherwise the entry with the most items is used (lowest index on ties).

    Each yielded value is a list whose first element is the index of the
    chosen entry, followed by the elements of the batch produced by the
    corresponding inner loader.  The inner batch must therefore be a list
    (this is what the default collate function returns for tuple samples).

    Args:
        data_iters: Re-iterable collection of datasets.  Every dataset must
            support ``len()`` and be accepted by ``DataLoader``.
        batch_size: Batch size handed to every inner ``DataLoader`` and used
            as the "ready" threshold when choosing a source.
    """

    def __init__(self, data_iters, batch_size):
        # NOTE(review): DataLoader.__init__ is not invoked here, presumably
        # because it assigns ``self.num_workers``, which this class exposes
        # as a read-only property -- confirm before adding a super() call.
        self.data_iters = data_iters
        self.batch_size = batch_size
        self.data_loaders = [
            iter(DataLoader(data_iter, batch_size=batch_size))
            for data_iter in data_iters
        ]

    def __iter__(self):
        """Yield ``[source_index, *batch]`` until the chosen source runs dry.

        Iteration ends cleanly as soon as the inner iterator selected for the
        current step is exhausted.  With no sources at all, nothing is
        yielded.
        """
        if not self.data_loaders:
            # np.argmax would raise on an empty array; nothing to serve.
            return
        while True:
            buffer_sizes = np.array([len(buffer) for buffer in self.data_iters])
            ready = buffer_sizes >= self.batch_size
            if ready.any():
                # argmax over the reversed mask finds the LAST ready source.
                idx_yield = len(buffer_sizes) - 1 - np.argmax(ready[::-1])
            else:
                idx_yield = np.argmax(buffer_sizes)
            try:
                batch = next(self.data_loaders[idx_yield])
            except StopIteration:
                # A StopIteration escaping a generator body is converted to
                # RuntimeError (PEP 479), so end the iteration explicitly.
                return
            yield [idx_yield] + batch

    @property
    def num_workers(self):
        """Always ``0``; the inner loaders are built with default workers."""
        return 0