|
import numpy as np |
|
import random |
|
import torch |
|
|
|
|
|
def set_random_seed(seed: int): |
|
torch.manual_seed((seed) % (1 << 31)) |
|
torch.cuda.manual_seed((seed) % (1 << 31)) |
|
torch.cuda.manual_seed_all((seed) % (1 << 31)) |
|
np.random.seed((seed) % (1 << 31)) |
|
random.seed((seed) % (1 << 31)) |
|
torch.backends.cudnn.benchmark = False |
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
class StackedRandomGenerator: |
|
""" |
|
Wrapper for torch.Generator that allows specifying a different random seed for each |
|
sample in a minibatch. |
|
""" |
|
|
|
def __init__(self, device, seeds): |
|
super().__init__() |
|
self.generators = [ |
|
torch.Generator(device).manual_seed(int(seed) % (1 << 31)) for seed in seeds |
|
] |
|
|
|
def randn_rn(self, size, **kwargs): |
|
assert size[0] == len(self.generators) |
|
return torch.stack( |
|
[torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] |
|
) |
|
|
|
def randn_like(self, input): |
|
return self.randn_rn( |
|
input.shape, dtype=input.dtype, layout=input.layout, device=input.device |
|
) |
|
|
|
def randint(self, *args, size, **kwargs): |
|
assert size[0] == len(self.generators) |
|
return torch.stack( |
|
[ |
|
torch.randint(*args, size=size[1:], generator=gen, **kwargs) |
|
for gen in self.generators |
|
] |
|
) |
|
|