File size: 1,821 Bytes
d758c99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import random
import sys
import numpy as np
import torch
class RandomState:
    """Snapshot of the Python, NumPy and PyTorch RNG states.

    The snapshot is taken when the object is constructed and covers the
    built-in ``random`` module, NumPy's global ``np.random`` generator, the
    PyTorch CPU generator, and the generator of every CUDA device reported by
    ``torch.cuda.device_count()``.  Call ``restore`` to put all of them back
    to exactly the captured state.
    """

    def __init__(self):
        # Python's built-in ``random`` module.
        self.random_mod_state = random.getstate()
        # NumPy's global generator behind the ``np.random.*`` functions.
        self.np_state = np.random.get_state()
        # PyTorch's default CPU generator.
        self.torch_cpu_state = torch.get_rng_state()
        # One state per CUDA device index; an empty list when the count is 0.
        n_devices = torch.cuda.device_count()
        self.torch_gpu_states = list(map(torch.cuda.get_rng_state, range(n_devices)))

    def restore(self):
        """Reset every generator to the state captured by ``__init__``."""
        random.setstate(self.random_mod_state)
        np.random.set_state(self.np_state)
        torch.set_rng_state(self.torch_cpu_state)
        # States were captured in device-index order, so the list position
        # is the device to write back to.
        for device_idx, gpu_state in enumerate(self.torch_gpu_states):
            torch.cuda.set_rng_state(gpu_state, device=device_idx)
class RandomContext:
    """Save and restore state of PyTorch, NumPy, Python RNGs.

    The context owns a private RNG state that is seeded once, at
    construction, from ``seed``.  Entering the context stashes the caller's
    RNG state and swaps the private state in.  Leaving it saves the private
    state for the next ``with`` block and puts the caller's state back.
    Successive ``with`` blocks on the same object therefore continue a single
    random stream without disturbing the surrounding program.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.  When ``None``, ``random``/``np.random``
            use their own default seeding and the torch seed is drawn from
            the freshly seeded ``random`` module.
            NOTE(review): ``np.random.seed`` only accepts integers in
            ``[0, 2**32 - 1]``; other integer seeds raise ``ValueError``.

    Raises:
        RuntimeError: on entering a context that is already active.
    """

    def __init__(self, seed=None):
        outside_state = RandomState()
        try:
            random.seed(seed)
            np.random.seed(seed)
            if seed is None:
                torch.manual_seed(random.randint(-sys.maxsize - 1, sys.maxsize))
            else:
                torch.manual_seed(seed)
            # torch.cuda.manual_seed_all is called by torch.manual_seed
            self.inside_state = RandomState()
        finally:
            # Always give the caller back their RNG state, even when seeding
            # fails part-way (e.g. a seed NumPy rejects after random.seed has
            # already run), so a failed construction leaks no global changes.
            outside_state.restore()
        # Caller's state while the context is active; None when inactive.
        self.outside_state = None
        self._active = False

    def __enter__(self):
        if self._active:
            raise RuntimeError('RandomContext can be active only once')
        # Save current state of RNG
        self.outside_state = RandomState()
        # Restore saved state of RNG for this context
        self.inside_state.restore()
        self._active = True
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Save current state of RNG so the next `with` block continues from here
        self.inside_state = RandomState()
        # Restore state of RNG saved in __enter__
        self.outside_state.restore()
        self.outside_state = None
        self._active = False
        # Implicitly returns None: exceptions raised inside the block propagate.
|