from collections.abc import Iterable, Mapping

from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler


class SafeBatchSampler(BatchSampler):
    """
    A safe `batch_sampler` that skips samples containing `None` values, supports shuffling,
    and keeps a fixed batch size.

    Args:
        data_source (Dataset): The dataset to sample from.
        batch_size (int): The size of each batch.
        drop_last (bool, optional): Whether to drop the last batch if it is smaller than `batch_size`. Defaults to `False`.
        shuffle (bool, optional): Whether to shuffle the data before sampling. Defaults to `True`.
        sampler (Sampler, optional): A custom index sampler; if given, `shuffle` is ignored. Defaults to `None`.

    Example:
        >>> dataloader = DataLoader(dataset, batch_sampler=SafeBatchSampler(dataset, batch_size, drop_last, shuffle))
    """
    def __init__(self, data_source, batch_size: int, drop_last: bool = False, shuffle: bool = True, sampler=None):
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
        if sampler is None:
            # Fall back to the standard PyTorch samplers when no custom sampler is given.
            if shuffle:
                sampler = RandomSampler(data_source)  # type: ignore[arg-type]
            else:
                sampler = SequentialSampler(data_source)  # type: ignore[arg-type]
        super().__init__(sampler, batch_size, drop_last)
        self.data_source = data_source
    def __iter__(self):
        batch = [0] * self.batch_size
        idx_in_batch = 0
        for idx in self.sampler:
            sample = self.data_source[idx]
            # Normalise the sample into something iterable so its values can be checked for None.
            if isinstance(sample, (Iterable, Mapping)) and not isinstance(sample, str):
                if isinstance(sample, Mapping):
                    sample = sample.values()
            else:
                sample = [sample]
            # Keep the index only if none of the sample's values is None.
            if all(v is not None for v in sample):
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
        # Emit the final, possibly smaller batch unless drop_last is set.
        if idx_in_batch > 0 and not self.drop_last:
            yield batch[:idx_in_batch]
    def __len__(self):
        # The exact number of batches cannot be known in advance because samples containing
        # None are skipped, so report the same upper bound as a plain BatchSampler.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size