strexp / captum /concept /_utils /data_iterator.py
markytools's picture
added strexp
d61b9c7
raw
history blame
1.92 kB
#!/usr/bin/env python3
import glob
import os
from typing import Callable, Iterator
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
class CustomIterableDataset(IterableDataset):
r"""
An auxiliary class for iterating through a dataset.
"""
def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None:
r"""
Args:
transform_filename_to_tensor (callable): Function to read a data
file from path and return a tensor from that file.
path (str): Path to dataset files. This can be either a path to a
directory or a file where input examples are stored.
"""
self.file_itr = None
self.path = path
if os.path.isdir(self.path):
self.file_itr = glob.glob(self.path + "*")
self.transform_filename_to_tensor = transform_filename_to_tensor
def __iter__(self) -> Iterator[Tensor]:
r"""
Returns:
iter (Iterator[Tensor]): A map from a function that
processes a list of file path(s) to a list of Tensors.
"""
if self.file_itr is not None:
return map(self.transform_filename_to_tensor, self.file_itr)
else:
return self.transform_filename_to_tensor(self.path)
def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader:
r"""
An auxiliary function that creates torch DataLoader from torch Dataset
using input `batch_size`.
Args:
dataset (Dataset): A torch dataset that allows to iterate over
the batches of examples.
batch_size (int, optional): Batch size of for each tensor in the
iteration.
Returns:
dataloader_iter (DataLoader): a DataLoader for data iteration.
"""
return DataLoader(dataset, batch_size=batch_size)