File size: 1,919 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python3

import glob
import os
from typing import Callable, Iterator

from torch import Tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset


class CustomIterableDataset(IterableDataset):
    r"""
    An auxiliary class for iterating through a dataset.
    """

    def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None:
        r"""
        Args:
            transform_filename_to_tensor (callable): Function to read a data
                        file from path and return a tensor from that file.
            path (str): Path to dataset files. This can be either a path to a
                        directory or a file where input examples are stored.
        """
        self.file_itr = None
        self.path = path

        if os.path.isdir(self.path):
            self.file_itr = glob.glob(self.path + "*")

        self.transform_filename_to_tensor = transform_filename_to_tensor

    def __iter__(self) -> Iterator[Tensor]:
        r"""
        Returns:
            iter (Iterator[Tensor]): A map from a function that
                processes a list of file path(s) to a list of Tensors.
        """
        if self.file_itr is not None:
            return map(self.transform_filename_to_tensor, self.file_itr)
        else:
            return self.transform_filename_to_tensor(self.path)


def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader:
    r"""
    An auxiliary function that creates torch DataLoader from torch Dataset
    using input `batch_size`.

    Args:
        dataset (Dataset): A torch dataset that allows to iterate over
            the batches of examples.
        batch_size (int, optional): Batch size of for each tensor in the
            iteration.

    Returns:
        dataloader_iter (DataLoader): a DataLoader for data iteration.
    """

    return DataLoader(dataset, batch_size=batch_size)