#!/usr/bin/env python3

import glob
import os
from typing import Callable, Iterator, List, Optional

from torch import Tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset


class CustomIterableDataset(IterableDataset):
    r"""
    An auxiliary class for iterating through a dataset.

    The dataset is backed either by a directory or by a single file:

    * directory: every entry matched by ``<path>/*`` (in sorted order) is
      passed to ``transform_filename_to_tensor`` and yields one item;
    * single file: the object returned by ``transform_filename_to_tensor``
      for that file is iterated, so a tensor yields its slices along
      dimension 0.

    Note:
        No per-worker sharding is performed. If this dataset is wrapped in a
        ``DataLoader`` with ``num_workers > 0``, every worker iterates the
        full dataset and items are duplicated.
    """

    def __init__(
        self, transform_filename_to_tensor: Callable, path: str
    ) -> None:
        r"""
        Args:

            transform_filename_to_tensor (Callable): Function to read a data
                        file from path and return a tensor (or any iterable
                        of tensors) from that file.
            path (str): Path to dataset files. This can be either a path to a
                        directory or a file where input examples are stored.
        """
        super().__init__()
        # Sorted list of entry paths when `path` is a directory, else None.
        self.file_itr: Optional[List[str]] = None
        self.path = path

        if os.path.isdir(self.path):
            # os.path.join inserts the separator when `path` has no trailing
            # slash; plain `path + "*"` would instead match `path` itself and
            # sibling entries sharing its prefix. Sorting makes the order
            # deterministic, since glob returns entries in arbitrary order.
            self.file_itr = sorted(glob.glob(os.path.join(self.path, "*")))

        self.transform_filename_to_tensor = transform_filename_to_tensor

    def __iter__(self) -> Iterator[Tensor]:
        r"""
        Returns:
            iter (Iterator[Tensor]): For a directory, a lazy map applying
                        ``transform_filename_to_tensor`` to each entry path.
                        For a single file, an iterator over the object the
                        transform returns for that file.
        """
        if self.file_itr is not None:
            return map(self.transform_filename_to_tensor, self.file_itr)
        # `__iter__` must return an iterator. A Tensor (or list) is only
        # iterable, so without `iter` the call `iter(dataset)` raises
        # TypeError. If the transform already returns an iterator, `iter`
        # hands it back unchanged.
        return iter(self.transform_filename_to_tensor(self.path))


def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader:
    r"""
    An auxiliary function that creates a torch DataLoader from a torch Dataset
    using the input `batch_size`.

    Args:

        dataset (Dataset): A torch dataset whose examples will be collated
                    into batches.
        batch_size (int, optional): Number of examples in each batch yielded
                    by the DataLoader.
                    Default: 64

    Returns:
        dataloader_iter (DataLoader): a DataLoader for data iteration.
    """
    return DataLoader(dataset, batch_size=batch_size)