"""
DeepScreen package initialization: registers custom OmegaConf resolvers and monkey-patches some third-party libraries.
"""
import sys
from builtins import eval
import lightning.fabric.strategies.launchers.subprocess_script as subprocess_script
import torch
from omegaconf import OmegaConf
from deepscreen.utils import get_logger
log = get_logger(__name__)
# Allow basic Python expressions in Hydra config interpolation; examples:
# `in_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}}`
# `subdir: ${eval:${hydra.job.override_dirname}.replace('/', '.')}`
OmegaConf.register_new_resolver("eval", eval)
def sanitize_path(path_str: str):
"""
Sanitize a string for path creation by replacing unsafe characters and cutting length to 255 (OS limitation).
"""
return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")[:255]
OmegaConf.register_new_resolver("sanitize_path", sanitize_path)
def _hydra_subprocess_cmd(local_rank: int):
"""
Monkey patching for lightning.fabric.strategies.launchers.subprocess_script._hydra_subprocess_cmd
Temporarily fixes the problem of unnecessarily creating log folders for DDP subprocesses in Hydra multirun/sweep.
"""
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
# when user is using hydra find the absolute path
if __main__.__spec__ is None: # pragma: no-cover
command = [sys.executable, to_absolute_path(sys.argv[0])]
else:
command = [sys.executable, "-m", __main__.__spec__.name]
command += sys.argv[1:]
cwd = get_original_cwd()
rundir = f'"{HydraConfig.get().runtime.output_dir}"'
# Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
command += [f"hydra.job.name=train_ddp_process_{local_rank}",
"hydra.output_subdir=null,"
f"hydra.runtime.output_dir={rundir}"]
return command, cwd
subprocess_script._hydra_subprocess_cmd = _hydra_subprocess_cmd
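# For illustration, with the patch applied a Hydra multirun DDP launch re-invokes the entry script with
# overrides roughly like the following (paths and rank are hypothetical):
#
#   <python> train.py <original overrides> hydra.job.name=train_ddp_process_1 \
#       hydra.output_subdir=null hydra.runtime.output_dir="multirun/.../0"
#
# so each spawned rank reuses the parent's output_dir instead of creating a fresh log folder.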
# from torch import Tensor
# from lightning.fabric.utilities.distributed import _distributed_available
# from lightning.pytorch.utilities.rank_zero import WarningCache
# from lightning.pytorch.utilities.warnings import PossibleUserWarning
# from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
# warning_cache = WarningCache()
#
# @staticmethod
# def _get_cache(result_metric, on_step: bool):
# cache = None
# if on_step and result_metric.meta.on_step:
# cache = result_metric._forward_cache
# elif not on_step and result_metric.meta.on_epoch:
# if result_metric._computed is None:
# should = result_metric.meta.sync.should
# if not should and _distributed_available() and result_metric.is_tensor:
# warning_cache.warn(
# f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
# " when logging on epoch level in distributed setting to accumulate the metric across"
# " devices.",
# category=PossibleUserWarning,
# )
# result_metric.compute()
# result_metric.meta.sync.should = should
#
# cache = result_metric._computed
#
# if cache is not None:
# if isinstance(cache, Tensor):
# if not result_metric.meta.enable_graph:
# return cache.detach()
#
# return cache
#
#
# _ResultCollection._get_cache = _get_cache
if torch.cuda.is_available():
    if torch.cuda.get_device_capability() >= (8, 0):
        torch.set_float32_matmul_precision("high")
        log.info("Your GPU supports Tensor Cores; "
                 "enabling them automatically by setting `torch.set_float32_matmul_precision('high')`.")