|
from torch.utils.data import WeightedRandomSampler |
|
import torch |
|
import numpy as np |
|
class CustomWeightedRandomSampler(WeightedRandomSampler):
    """WeightedRandomSampler that draws its indices with NumPy instead of torch.

    Takes the same constructor arguments as
    :class:`torch.utils.data.WeightedRandomSampler` (``__init__`` and
    ``__len__`` are inherited unchanged). Only ``__iter__`` is overridden:
    indices are drawn with NumPy's ``choice`` so that sampling is not subject
    to the 2**24 size limit of the torch sampling routine used by the parent.

    Randomness: if a ``generator`` was given to the constructor, a fresh seed
    is drawn from it on every iteration, so seeding that generator makes the
    sampled indices reproducible. Without a generator, NumPy's global RNG
    (``np.random``) is used, exactly as before.
    """

    def __iter__(self):
        """Return an iterator over ``self.num_samples`` sampled indices.

        Each index is a Python ``int`` in ``range(len(self.weights))``, drawn
        with probability proportional to its weight, with or without
        replacement according to ``self.replacement``.

        Raises:
            ValueError: if the weights do not sum to a positive, finite value
                (NumPy would otherwise fail on NaN probabilities), or if NumPy
                rejects the request (e.g. negative weights, or more samples
                than non-zero weights when sampling without replacement).
        """
        # detach()/cpu() so weights that require grad or live on an
        # accelerator can still be converted to a NumPy array.
        weights = self.weights.detach().cpu().numpy()
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError("weights must sum to a positive, finite value")
        probs = weights / total

        # getattr: older torch versions of the parent have no `generator`.
        generator = getattr(self, "generator", None)
        if generator is None:
            rng = np.random
        else:
            # Pull one int64 from the torch generator and use it to seed a
            # NumPy Generator, so torch-side seeding controls this sampler.
            seed = int(
                torch.empty((), dtype=torch.int64)
                .random_(generator=generator)
                .item()
            )
            rng = np.random.default_rng(seed)

        indices = rng.choice(
            len(weights),
            size=self.num_samples,
            replace=self.replacement,
            p=probs,
        )
        # ndarray.tolist() already yields plain Python ints.
        return iter(indices.tolist())
|
|