from typing import Mapping, Iterable

from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler


class SafeBatchSampler(BatchSampler):
    """
    A safe `batch_sampler` that skips samples containing `None` values,
    supports shuffling, and keeps a fixed batch size.

    Args:
        data_source (Dataset): The dataset to sample from.
        batch_size (int): The size of each batch.
        drop_last (bool, optional): Whether to drop the last batch if its size is
            smaller than `batch_size`. Defaults to `False`.
        shuffle (bool, optional): Whether to shuffle the data before sampling.
            Defaults to `True`.
        sampler (Sampler, optional): A custom index sampler. If given, it takes
            precedence over `shuffle`. Defaults to `None`.

    Example:
        >>> dataloader = DataLoader(dataset, batch_sampler=SafeBatchSampler(dataset, batch_size, drop_last, shuffle))
    """

    def __init__(self, data_source, batch_size: int, drop_last: bool = False,
                 shuffle: bool = True, sampler=None):
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")

        if sampler is not None:
            pass
        elif shuffle:
            sampler = RandomSampler(data_source)  # type: ignore[arg-type]
        else:
            sampler = SequentialSampler(data_source)  # type: ignore[arg-type]

        super().__init__(sampler, batch_size, drop_last)
        self.data_source = data_source

    def __iter__(self):
        batch = [0] * self.batch_size
        idx_in_batch = 0
        for idx in self.sampler:
            sample = self.data_source[idx]
            # Normalize the sample into an iterable of field values so every
            # field can be checked for `None`.
            if isinstance(sample, (Iterable, Mapping)) and not isinstance(sample, str):
                if isinstance(sample, Mapping):
                    sample = sample.values()
            else:
                sample = [sample]
            # Keep the index only if every field of the sample is present.
            if all(v is not None for v in sample):
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
        if idx_in_batch > 0 and not self.drop_last:
            yield batch[:idx_in_batch]

    def __len__(self):
        # The exact number of batches is unknown until iteration because
        # samples containing `None` are skipped on the fly; report the upper
        # bound implied by the underlying sampler instead.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
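

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the class above):
# a hypothetical in-memory ToyDataset with some `None` fields, showing that
# SafeBatchSampler drops incomplete samples while keeping the batch size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """A hypothetical dataset where some samples contain `None`."""

        def __init__(self):
            self.rows = [(0, "a"), (1, None), (2, "c"), (None, "d"), (4, "e")]

        def __getitem__(self, idx):
            return self.rows[idx]

        def __len__(self):
            return len(self.rows)

    dataset = ToyDataset()
    batch_sampler = SafeBatchSampler(dataset, batch_size=2, drop_last=False, shuffle=False)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    for batch in dataloader:
        # Only rows (0, "a"), (2, "c"), and (4, "e") survive the None filter,
        # so two batches are produced: one full batch and a final partial one.
        print(batch)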