# import time
# from pathlib import Path
# from typing import Any, Dict, List
#
# import hydra
# from pytorch_lightning import Callback
# from pytorch_lightning.loggers import Logger
# from pytorch_lightning.utilities import rank_zero_only
import warnings
from importlib.util import find_spec
from typing import Callable

from omegaconf import DictConfig

from deepscreen.utils import get_logger, enforce_tags, print_config_tree

log = get_logger(__name__)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before a job is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
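
# A minimal usage sketch for `extras`, kept commented out (like the reference
# code at the bottom of this file) so importing this module has no side effects.
# The config keys mirror the flags read above; the exact layout of the `extras`
# group in this project's Hydra configs is an assumption.
#
# from omegaconf import OmegaConf
#
# cfg = OmegaConf.create({
#     "extras": {"ignore_warnings": True, "enforce_tags": False, "print_config": False}
# })
# extras(cfg)  # disables python warnings; skips tag enforcement and config printing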


def job_wrapper(extra_utils: bool) -> Callable:
    """Optional decorator that controls the failure behavior and extra utilities when executing a job function.

    This wrapper can be used to:
        - make sure loggers are closed even if the job function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.job_wrapper(extra_utils)
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```
    """

    def decorator(job_func):
        def wrapped_func(cfg: DictConfig):
            # execute the job
            try:
                # apply extra utilities
                if extra_utils:
                    extras(cfg)

                metric_dict, object_dict = job_func(cfg=cfg)

            # things to do if exception occurs
            except Exception as ex:
                # save exception to `.log` file
                log.exception("")

                # some hyperparameter combinations might be invalid or cause out-of-memory errors
                # so when using hparam search plugins like Optuna, you might want to disable
                # raising the below exception to avoid multirun failure
                raise ex

            # things to always do after either success or exception
            finally:
                # display output dir path in terminal
                log.info(f"Output dir: {cfg.paths.output_dir}")

                # always close wandb run (even if exception occurs so multirun won't fail)
                if find_spec("wandb"):  # check if wandb is installed
                    import wandb

                    if wandb.run:
                        log.info("Closing wandb!")
                        wandb.finish()

            return metric_dict, object_dict

        return wrapped_func

    return decorator
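
# Usage sketch for `job_wrapper` with a hypothetical `train` job (commented out;
# the real entry points live in the training/evaluation scripts). Note that the
# wrapper's `finally` block reads `cfg.paths.output_dir`, so any config passed
# to a wrapped job must define that key.
#
# @job_wrapper(extra_utils=True)
# def train(cfg: DictConfig) -> tuple[dict, dict]:
#     metric_dict = {"train/loss": 0.0}  # placeholder metrics
#     object_dict = {}                   # placeholder objects (model, trainer, ...)
#     return metric_dict, object_dict
#
# metric_dict, object_dict = train(cfg)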
# @rank_zero_only
# def save_file(path, content) -> None:
#     """Save file in rank zero mode (only on one process in multi-GPU setup)."""
#     with open(path, "w+") as file:
#         file.write(content)
#
#
# def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
#     """Instantiates callbacks from config."""
#     callbacks: List[Callback] = []
#
#     if not callbacks_cfg:
#         log.warning("Callbacks config is empty.")
#         return callbacks
#
#     if not isinstance(callbacks_cfg, DictConfig):
#         raise TypeError("Callbacks config must be a DictConfig!")
#
#     for _, cb_conf in callbacks_cfg.items():
#         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
#             log.info(f"Instantiating callback <{cb_conf._target_}>")
#             callbacks.append(hydra.utils.instantiate(cb_conf))
#
#     return callbacks
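#
# For reference, `callbacks_cfg` maps arbitrary names to configs carrying a
# `_target_` key that `hydra.utils.instantiate` resolves to a class. A minimal
# sketch (the ModelCheckpoint arguments are illustrative, not this project's
# actual config values):
#
#     cb_cfg = OmegaConf.create({
#         "model_checkpoint": {
#             "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
#             "monitor": "val/loss",
#             "mode": "min",
#         }
#     })
#     callbacks = instantiate_callbacks(cb_cfg)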
#
#
# def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
#     """Instantiates loggers from config."""
#     logger: List[Logger] = []
#
#     if not logger_cfg:
#         log.warning("Logger config is empty.")
#         return logger
#
#     if not isinstance(logger_cfg, DictConfig):
#         raise TypeError("Logger config must be a DictConfig!")
#
#     for _, lg_conf in logger_cfg.items():
#         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
#             log.info(f"Instantiating logger <{lg_conf._target_}>")
#             logger.append(hydra.utils.instantiate(lg_conf))
#
#     return logger
#
#
# @rank_zero_only
# def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
#     """Controls which config parts are saved by lightning loggers.
#
#     Additionally saves:
#         - Number of model parameters
#     """
#     hparams = {}
#
#     cfg = object_dict["cfg"]
#     model = object_dict["model"]
#     trainer = object_dict["trainer"]
#
#     if not trainer.logger:
#         log.warning("Logger not found! Skipping hyperparameter logging.")
#         return
#
#     hparams["model"] = cfg["model"]
#
#     # TODO: accommodate LazyModule
#     # save number of model parameters
#     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
#     hparams["model/params/trainable"] = sum(
#         p.numel() for p in model.parameters() if p.requires_grad
#     )
#     hparams["model/params/non_trainable"] = sum(
#         p.numel() for p in model.parameters() if not p.requires_grad
#     )
#
#     hparams["data"] = cfg["data"]
#     hparams["trainer"] = cfg["trainer"]
#
#     hparams["callbacks"] = cfg.get("callbacks")
#     hparams["extras"] = cfg.get("extras")
#
#     hparams["job_name"] = cfg.get("job_name")
#     hparams["tags"] = cfg.get("tags")
#     hparams["ckpt_path"] = cfg.get("ckpt_path")
#     hparams["seed"] = cfg.get("seed")
#
#     # send hparams to all loggers
#     trainer.logger.log_hyperparams(hparams)
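#
# `object_dict` is expected to carry at least the keys read above ("cfg",
# "model", "trainer"); a sketch of the call, as a training job would assemble
# it (names here are illustrative):
#
#     object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
#     log_hyperparameters(object_dict)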
#
#
# def close_loggers() -> None:
#     """Makes sure all loggers are closed properly (prevents logging failure during multirun)."""
#     log.info("Closing loggers.")
#
#     if find_spec("wandb"):  # if wandb is installed
#         import wandb
#
#         if wandb.run:
#             log.info("Closing wandb!")
#             wandb.finish()