"""
DeepScreen package initialization: registers custom OmegaConf resolvers and applies monkey patches to third-party libraries.
"""
import sys
from builtins import eval

import lightning.fabric.strategies.launchers.subprocess_script as subprocess_script
import torch
from omegaconf import OmegaConf

from deepscreen.utils import get_logger

log = get_logger(__name__)

# Allow basic Python expressions in Hydra/OmegaConf interpolation; examples:
# `in_channels: ${eval:${model.drug_encoder.out_channels}+${model.protein_encoder.out_channels}}`
# `subdir: ${eval:${hydra.job.override_dirname}.replace('/', '.')}`
OmegaConf.register_new_resolver("eval", eval)


def sanitize_path(path_str: str):
    """
    Sanitize a string for use in a filesystem path by replacing unsafe characters and truncating it to 255 characters (the filename length limit on most filesystems).
    """
    return path_str.replace("/", ".").replace("\\", ".").replace(":", "-")[:255]


OmegaConf.register_new_resolver("sanitize_path", sanitize_path)


def _hydra_subprocess_cmd(local_rank: int):
    """
    Monkey patching for lightning.fabric.strategies.launchers.subprocess_script._hydra_subprocess_cmd
    Temporarily fixes the problem of unnecessarily creating log folders for DDP subprocesses in Hydra multirun/sweep.
    """
    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path

    # when the user is using Hydra, find the absolute path to the entry script
    if __main__.__spec__ is None:  # pragma: no-cover
        command = [sys.executable, to_absolute_path(sys.argv[0])]
    else:
        command = [sys.executable, "-m", __main__.__spec__.name]

    command += sys.argv[1:]

    cwd = get_original_cwd()
    rundir = f'"{HydraConfig.get().runtime.output_dir}"'
    # Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
    command += [f"hydra.job.name=train_ddp_process_{local_rank}",
                "hydra.output_subdir=null,"
                f"hydra.runtime.output_dir={rundir}"]
    return command, cwd


subprocess_script._hydra_subprocess_cmd = _hydra_subprocess_cmd
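# Rough sketch (hypothetical paths) of the per-rank command the patched function yields for a
# multirun launch such as `python train.py -m trainer=ddp`:
#   [sys.executable, "/abs/path/train.py", "-m", "trainer=ddp",
#    "hydra.job.name=train_ddp_process_1", "hydra.output_subdir=null",
#    'hydra.runtime.output_dir="<parent run output dir>"']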

# from torch import Tensor
# from lightning.fabric.utilities.distributed import _distributed_available
# from lightning.pytorch.utilities.rank_zero import WarningCache
# from lightning.pytorch.utilities.warnings import PossibleUserWarning
# from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection

# warning_cache = WarningCache()
#
# @staticmethod
# def _get_cache(result_metric, on_step: bool):
#     cache = None
#     if on_step and result_metric.meta.on_step:
#         cache = result_metric._forward_cache
#     elif not on_step and result_metric.meta.on_epoch:
#         if result_metric._computed is None:
#             should = result_metric.meta.sync.should
#             if not should and _distributed_available() and result_metric.is_tensor:
#                 warning_cache.warn(
#                     f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
#                     " when logging on epoch level in distributed setting to accumulate the metric across"
#                     " devices.",
#                     category=PossibleUserWarning,
#                 )
#             result_metric.compute()
#             result_metric.meta.sync.should = should
#
#         cache = result_metric._computed
#
#         if cache is not None:
#             if isinstance(cache, Tensor):
#                 if not result_metric.meta.enable_graph:
#                     return cache.detach()
#
#     return cache
#
#
# _ResultCollection._get_cache = _get_cache

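# Note: compute capability >= 8.0 (Ampere or newer) implies TF32-capable Tensor Cores; the "high"
# setting lets float32 matmuls run on them with TF32, trading a little precision for speed.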
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0):
    torch.set_float32_matmul_precision("high")
    log.info("Your GPU supports Tensor Cores; "
             "enabling them automatically by setting `torch.set_float32_matmul_precision('high')`.")