# Re-exports the package's utility helpers and defines a seeding helper.
import random

import numpy as np
import torch

from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, task_wrapper

# Largest seed value passed on to the random number generators.
_MAX_SEED = 1 << 31


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is normalised before use: a negative value is replaced by its
    absolute value, and a value greater than ``1 << 31`` is clamped to
    ``1 << 31``.

    When CUDA is available the CUDA generators are seeded as well. When
    cuDNN is available, ``torch.backends.cudnn.deterministic`` is set to
    ``True`` and ``torch.backends.cudnn.benchmark`` is set to ``False``;
    these are process-wide settings and stay in effect after the call.

    Args:
        seed: Requested seed value.
    """
    if seed < 0:
        seed = -seed
    if seed > _MAX_SEED:
        seed = _MAX_SEED

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]