from .braceexpand import braceexpand | |
from .context import autocast_exclude_mps | |
from .file import get_latest_checkpoint | |
from .instantiators import instantiate_callbacks, instantiate_loggers | |
from .logger import RankedLogger | |
from .logging_utils import log_hyperparameters | |
from .rich_utils import enforce_tags, print_config_tree | |
from .utils import extras, get_metric_value, task_wrapper | |
def set_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs and make cuDNN deterministic.

    The seed is normalised before use: negative values are replaced by their
    absolute value, and values above ``2**31`` are clamped to ``2**31``. This
    keeps the seed within the range accepted by NumPy's legacy seeder
    (``0 <= seed < 2**32``).

    When cuDNN is available, this also sets ``torch.backends.cudnn.deterministic = True``
    and ``torch.backends.cudnn.benchmark = False``. That trades some speed for
    run-to-run reproducibility.

    :param seed: The seed to apply. It may be negative; see normalisation above.
    """
    # Imported locally so the function works even though this module does not
    # import these names at top level.
    import random

    import numpy as np
    import torch

    # Normalise the seed: mirror negatives into the non-negative range, then clamp.
    seed = min(abs(seed), 1 << 31)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible CUDA device, including the current one.
        torch.cuda.manual_seed_all(seed)
    if torch.backends.cudnn.is_available():
        # Force deterministic cuDNN kernels. Autotuning (benchmark mode) can
        # pick different algorithms between runs, so it is disabled.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Public API of this utils package: names exported by ``from <package> import *``.
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]