File size: 718 Bytes
9016314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import pickle

def get_dataset(args, split):
    """Build the dataset selected by ``args.dataset`` for the given split.

    Args:
        args: configuration namespace; reads ``dataset`` and, for the
            ``'speech'`` dataset, ``batch_size``, ``set_size`` and
            ``mask_type``.
        split: split identifier passed through to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'speech':
        # Imported lazily so ``.speech`` is only loaded when it is selected.
        from .speech import Dataset
        return Dataset(split, args.batch_size, args.set_size, args.mask_type)

    raise ValueError('Unknown dataset: {!r}'.format(args.dataset))
    
def _atomic_pickle_dump(obj, fname):
    """Pickle ``obj`` to ``fname`` without ever exposing a partially written file.

    The data is first written to a sibling temporary file. It is then moved
    into place with ``os.replace``, which is atomic on POSIX when both paths
    are on the same filesystem. If anything fails (including
    ``KeyboardInterrupt``), the temporary file is removed and the error is
    re-raised. ``fname`` is therefore either left as it was or fully written.
    """
    # The PID suffix keeps concurrent writers in different processes from
    # clobbering each other's temp file.
    tmp_fname = '{}.{}.tmp'.format(fname, os.getpid())
    try:
        with open(tmp_fname, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_fname, fname)
    except BaseException:
        try:
            os.remove(tmp_fname)
        except OSError:
            pass
        raise


def cache(args, split, fname):
    """Return the list of batches for ``split``, using ``fname`` as a pickle cache.

    If ``fname`` exists, it is unpickled and returned as-is. Otherwise:

    1. The dataset is built with ``get_dataset(args, split)`` and
       ``initialize()`` is called on it.
    2. ``num_batches`` batches are drawn via ``next_batch()``.
    3. The resulting list is pickled to ``fname`` and then returned.

    Args:
        args: configuration namespace forwarded to ``get_dataset``.
        split: split identifier forwarded to ``get_dataset``.
        fname: path of the pickle cache file.

    Returns:
        list: the batches, in the order ``next_batch()`` produced them.
    """
    # NOTE(review): the cache is keyed only by ``fname``. Reusing the same
    # path with different ``args``/``split`` returns stale batches.
    if os.path.isfile(fname):
        # pickle.load can execute arbitrary code. Only point ``fname`` at
        # cache files this code wrote itself, never at untrusted input.
        with open(fname, 'rb') as f:
            return pickle.load(f)

    dataset = get_dataset(args, split)
    dataset.initialize()
    batches = [dataset.next_batch() for _ in range(dataset.num_batches)]
    # Write atomically. Otherwise an interrupted run could leave a truncated
    # cache file, and every later call would then fail in pickle.load.
    _atomic_pickle_dump(batches, fname)
    return batches