from torch.utils.data import WeightedRandomSampler
import torch
import numpy as np


class CustomWeightedRandomSampler(WeightedRandomSampler):
    """WeightedRandomSampler that draws its indices with NumPy.

    The stock sampler draws with ``torch.multinomial``, which limits the
    number of categories (2^24 per the PyTorch documentation). Drawing with
    NumPy instead allows larger weight vectors to be sampled.

    The constructor is inherited unchanged, so this class accepts exactly the
    arguments ``WeightedRandomSampler`` does.
    """

    def __iter__(self):
        """Return an iterator over ``num_samples`` sampled integer indices.

        Indices lie in ``[0, len(weights))`` and are drawn with probability
        proportional to their weight, with or without replacement according
        to ``self.replacement``.

        Randomness source:
            * If a ``generator`` was given to the constructor, a seed is drawn
              from it and used for a private NumPy generator, so results are
              reproducible from the torch generator alone.
            * Otherwise NumPy's global random state is used (seedable with
              ``np.random.seed``).

        Raises:
            ValueError: propagated from NumPy, e.g. when the normalised
                weights are not a valid probability vector or when sampling
                without replacement asks for more samples than there are
                non-zero weights.
        """
        # Normalise to a probability vector, as NumPy requires for ``p``.
        probabilities = self.weights.numpy() / torch.sum(self.weights).numpy()

        # Pass the population as an integer rather than a materialised range:
        # NumPy then samples from arange(n) without building it from a Python
        # iterable, which matters for very large weight vectors.
        population_size = len(self.weights)

        # ``generator`` may be absent on very old torch versions.
        generator = getattr(self, "generator", None)
        if generator is None:
            indices = np.random.choice(
                population_size,
                size=self.num_samples,
                replace=self.replacement,
                p=probabilities,
            )
        else:
            # Bridge the torch generator to NumPy by drawing one
            # non-negative int64 from it and using that as the seed.
            seed = int(
                torch.empty((), dtype=torch.int64)
                .random_(generator=generator)
                .item()
            )
            indices = np.random.default_rng(seed).choice(
                population_size,
                size=self.num_samples,
                replace=self.replacement,
                p=probabilities,
            )

        # ``tolist`` converts to plain Python ints, which is what samplers
        # are expected to yield.
        return iter(indices.tolist())