# import time
# from pathlib import Path
# from typing import Any, Dict, List
#
# import hydra
# from pytorch_lightning import Callback
# from pytorch_lightning.loggers import Logger
# from pytorch_lightning.utilities import rank_zero_only
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable

from omegaconf import DictConfig

from deepscreen.utils import get_logger, enforce_tags, print_config_tree

log = get_logger(__name__)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before a job is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
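

# A minimal usage sketch (assumption for illustration): `extras` is typically
# called at the top of a Hydra entry point, before any heavy setup. The config
# path and name below are not defined by this module.
#
# @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
# def main(cfg: DictConfig) -> None:
#     extras(cfg)  # apply warning filters, tag enforcement and Rich config printing
#     ...
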
def job_wrapper(extra_utils: bool) -> Callable:
    """Optional decorator that controls the failure behavior and extra utilities when executing a job function.

    This wrapper can be used to:
        - make sure loggers are closed even if the job function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.job_wrapper(extra_utils)
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```
    """

    def decorator(job_func):
        @wraps(job_func)  # preserve the job function's name and docstring
        def wrapped_func(cfg: DictConfig):
            # execute the job
            try:
                # apply extra utilities
                if extra_utils:
                    extras(cfg)

                metric_dict, object_dict = job_func(cfg=cfg)

            # things to do if exception occurs
            except Exception:
                # save exception to `.log` file
                log.exception("")

                # some hyperparameter combinations might be invalid or cause out-of-memory errors
                # so when using hparam search plugins like Optuna, you might want to disable
                # raising the below exception to avoid multirun failure
                raise

            # things to always do after either success or exception
            finally:
                # display output dir path in terminal
                log.info(f"Output dir: {cfg.paths.output_dir}")

                # always close wandb run (even if exception occurs so multirun won't fail)
                if find_spec("wandb"):  # check if wandb is installed
                    import wandb

                    if wandb.run:
                        log.info("Closing wandb!")
                        wandb.finish()

            return metric_dict, object_dict

        return wrapped_func

    return decorator
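

# A minimal sketch of a decorated job function (the `train` function and its
# return values below are illustrative assumptions, not part of this module):
#
# @job_wrapper(extra_utils=True)
# def train(cfg: DictConfig) -> tuple[dict, dict]:
#     metric_dict = {"val/loss": 0.0}  # metrics reported back to Hydra sweepers
#     object_dict = {"cfg": cfg}       # objects kept for downstream inspection
#     return metric_dict, object_dict
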
# @rank_zero_only
# def save_file(path, content) -> None:
#     """Save file in rank zero mode (only on one process in multi-GPU setup)."""
#     with open(path, "w+") as file:
#         file.write(content)
#
#
# def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
#     """Instantiates callbacks from config."""
#     callbacks: List[Callback] = []
#
#     if not callbacks_cfg:
#         log.warning("Callbacks config is empty.")
#         return callbacks
#
#     if not isinstance(callbacks_cfg, DictConfig):
#         raise TypeError("Callbacks config must be a DictConfig!")
#
#     for _, cb_conf in callbacks_cfg.items():
#         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
#             log.info(f"Instantiating callback <{cb_conf._target_}>")
#             callbacks.append(hydra.utils.instantiate(cb_conf))
#
#     return callbacks
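#
# For reference, the callbacks config consumed above would be a Hydra DictConfig
# mapping names to `_target_` entries, e.g. (hypothetical YAML, not shipped here):
#
#     model_checkpoint:
#         _target_: pytorch_lightning.callbacks.ModelCheckpoint
#         monitor: "val/loss"
#     early_stopping:
#         _target_: pytorch_lightning.callbacks.EarlyStopping
#         monitor: "val/loss"
#         patience: 3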
#
#
# def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
#     """Instantiates loggers from config."""
#     logger: List[Logger] = []
#
#     if not logger_cfg:
#         log.warning("Logger config is empty.")
#         return logger
#
#     if not isinstance(logger_cfg, DictConfig):
#         raise TypeError("Logger config must be a DictConfig!")
#
#     for _, lg_conf in logger_cfg.items():
#         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
#             log.info(f"Instantiating logger <{lg_conf._target_}>")
#             logger.append(hydra.utils.instantiate(lg_conf))
#
#     return logger
#
#
# @rank_zero_only
# def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
#     """Controls which config parts are saved by lightning loggers.
#
#     Additionally saves:
#         - Number of model parameters
#     """
#
#     hparams = {}
#
#     cfg = object_dict["cfg"]
#     model = object_dict["model"]
#     trainer = object_dict["trainer"]
#
#     if not trainer.logger:
#         log.warning("Logger not found! Skipping hyperparameter logging.")
#         return
#
#     hparams["model"] = cfg["model"]
#
#     # TODO Accommodation for LazyModule
#     # save number of model parameters
#     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
#     hparams["model/params/trainable"] = sum(
#         p.numel() for p in model.parameters() if p.requires_grad
#     )
#     hparams["model/params/non_trainable"] = sum(
#         p.numel() for p in model.parameters() if not p.requires_grad
#     )
#
#     hparams["data"] = cfg["data"]
#     hparams["trainer"] = cfg["trainer"]
#
#     hparams["callbacks"] = cfg.get("callbacks")
#     hparams["extras"] = cfg.get("extras")
#
#     hparams["job_name"] = cfg.get("job_name")
#     hparams["tags"] = cfg.get("tags")
#     hparams["ckpt_path"] = cfg.get("ckpt_path")
#     hparams["seed"] = cfg.get("seed")
#
#     # send hparams to all loggers
#     trainer.logger.log_hyperparams(hparams)
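#
# A hypothetical call site for the function above (the variable names are
# illustrative, not defined in this module):
#
#     log_hyperparameters({"cfg": cfg, "model": model, "trainer": trainer})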
#
#
# def close_loggers() -> None:
#     """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
#
#     log.info("Closing loggers.")
#
#     if find_spec("wandb"):  # if wandb is installed
#         import wandb
#
#         if wandb.run:
#             log.info("Closing wandb!")
#             wandb.finish()