# File size: 1,122 Bytes
# 048ff41
from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, task_wrapper
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed value. A negative value is replaced by its absolute value,
            and anything above ``2**31`` is clamped to ``2**31`` before use.
    """
    # Imported locally: this module does not import these names at top level,
    # so referring to them bare would raise NameError on every call.
    import random

    import numpy as np
    import torch

    if seed < 0:
        seed = -seed
    if seed > (1 << 31):
        seed = 1 << 31

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Public API of this package; `set_seed` is defined in this module and is
# listed here alongside the re-exported helpers.
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]