File size: 1,122 Bytes
048ff41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import random

import numpy as np
import torch

from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, task_wrapper

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is normalised before use: a negative value is replaced by its
    absolute value, and anything larger than ``2**31`` is clamped to
    ``2**31``.  Note that the clamp maps every oversized seed to the same
    value, so distinct large seeds produce identical random streams.

    When cuDNN is available it is additionally switched to deterministic
    mode with benchmarking (auto-tuning) disabled.

    Args:
        seed: Requested seed; any integer is accepted.
    """
    # Fold into the range [0, 2**31], which every seeded generator accepts.
    seed = min(abs(seed), 1 << 31)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # manual_seed_all covers every GPU, including the current one, so a
        # separate torch.cuda.manual_seed call is not needed.
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Names exported by ``from <package> import *``: the helpers re-exported from
# the sibling modules plus ``set_seed``, which is defined in this file.
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]