import torch


class AliasMethod(object):
    '''
    Alias method for efficient sampling from a discrete distribution with many outcomes.
    From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
    '''
    def __init__(self, probs):
        # Normalize the weights if they sum to more than 1.
        if probs.sum() > 1:
            probs.div_(probs.sum())

        K = len(probs)
        self.prob = torch.zeros(K)
        self.alias = torch.LongTensor([0] * K)

        # Sort the outcomes into those with probabilities
        # larger and smaller than 1/K.
        smaller = []
        larger = []
        for kk, prob in enumerate(probs):
            self.prob[kk] = K * prob
            if self.prob[kk] < 1.0:
                smaller.append(kk)
            else:
                larger.append(kk)

        # Loop through and create little binary mixtures that
        # appropriately allocate the larger outcomes over the
        # overall uniform mixture.
        while len(smaller) > 0 and len(larger) > 0:
            small = smaller.pop()
            large = larger.pop()

            self.alias[small] = large
            self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]

            if self.prob[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        # Any bins left over are (up to numerical error) exactly 1/K.
        for last_one in smaller + larger:
            self.prob[last_one] = 1

    def cuda(self):
        self.prob = self.prob.cuda()
        self.alias = self.alias.cuda()

    def draw(self, N):
        '''Draw N samples from the multinomial distribution.'''
        K = self.alias.size(0)

        # Pick a bin uniformly at random for each sample.
        kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
        prob = self.prob.index_select(0, kk)
        alias = self.alias.index_select(0, kk)

        # b == 1 with probability prob[kk]: keep the chosen bin itself,
        # otherwise fall back to that bin's alias.
        b = torch.bernoulli(prob)
        oq = kk.mul(b.long())
        oj = alias.mul((1 - b).long())

        return oq + oj
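

# Minimal usage sketch (not part of the original class): build a sampler from an
# arbitrary weight vector and compare the empirical sample frequencies against the
# target distribution. The weight values below are illustrative assumptions only.
if __name__ == "__main__":
    weights = torch.tensor([0.1, 0.3, 0.4, 0.2])
    sampler = AliasMethod(weights)

    # Draw a large batch of samples; counts should roughly match the weights.
    samples = sampler.draw(100000)
    counts = torch.bincount(samples, minlength=len(weights)).float()
    print("empirical:", (counts / counts.sum()).tolist())
    print("target:   ", weights.tolist())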