# Uploaded by libokj ("Upload 110 files"); original file size 6.77 kB.
# import time
# from pathlib import Path
# from typing import Any, Dict, List
# import hydra
# from pytorch_lightning import Callback
# from pytorch_lightning.loggers import Logger
# from pytorch_lightning.utilities import rank_zero_only
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable

from omegaconf import DictConfig

from deepscreen.utils import get_logger, enforce_tags, print_config_tree
# Module-level logger, created with the project's `get_logger` helper and
# keyed by this module's name; `extras` below uses it for its status messages.
log = get_logger(__name__)
def extras(cfg: DictConfig) -> None:
    """Apply optional utilities before a job is started.

    Each utility is toggled by a flag under ``cfg.extras``:

    - ``ignore_warnings``: ignore all Python warnings for the rest of the process
    - ``enforce_tags``: prompt the user for tags on the command line if none are
      provided in the config (``enforce_tags`` is called with ``save_to_file=True``)
    - ``print_config``: pretty-print the resolved config tree with Rich
      (``print_config_tree`` is called with ``save_to_file=True``)

    Args:
        cfg: The job config. If it has no ``extras`` section (missing, null or
            empty), a warning is logged and nothing else is done.
    """
    # return early if there is no `extras` config; without this guard the
    # `cfg.extras.get(...)` lookups below would fail on a missing/null section
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
def job_wrapper(extra_utils: bool) -> Callable:
    """Build a decorator that controls failure behavior and extra utilities of a job function.

    The returned decorator wraps ``job_func`` so that, when called with a config:

    - ``extras(cfg)`` is applied first if ``extra_utils`` is True
    - any exception raised by the job is logged with its traceback and re-raised
    - whether the job succeeds or fails, the output directory
      (``cfg.paths.output_dir``) is logged and an active wandb run, if wandb is
      installed, is finished so a multirun is not left with a dangling run

    It can be extended as needed, e.g. to mark failed runs with a dedicated file
    so they can be found and rerun later.

    Example::

        @job_wrapper(extra_utils=True)
        def train(cfg: DictConfig) -> Tuple[dict, dict]:
            ...
            return metric_dict, object_dict

    Args:
        extra_utils: Whether to run ``extras(cfg)`` before executing the job.

    Returns:
        A decorator. The function it produces calls ``job_func(cfg=cfg)`` and
        returns its ``(metric_dict, object_dict)`` result.
    """

    def decorator(job_func: Callable) -> Callable:
        # preserve the job's name/docstring on the wrapper
        @wraps(job_func)
        def wrapped_func(cfg: DictConfig):
            # execute the job
            try:
                # apply extra utilities
                if extra_utils:
                    extras(cfg)

                metric_dict, object_dict = job_func(cfg=cfg)

            # things to do if exception occurs
            except Exception:
                # save exception (with traceback) to the `.log` file
                log.exception("")

                # some hyperparameter combinations might be invalid or cause out-of-memory errors
                # so when using hparam search plugins like Optuna, you might want to disable
                # raising the below exception to avoid multirun failure
                raise

            # things to always do after either success or exception
            finally:
                # display output dir path in terminal
                log.info(f"Output dir: {cfg.paths.output_dir}")

                # always close wandb run (even if exception occurs so multirun won't fail)
                if find_spec("wandb"):  # check if wandb is installed
                    import wandb

                    if wandb.run:
                        log.info("Closing wandb!")
                        wandb.finish()

            return metric_dict, object_dict

        return wrapped_func

    return decorator
# @rank_zero_only
# def save_file(path, content) -> None:
# """Save file in rank zero mode (only on one process in multi-GPU setup)."""
# with open(path, "w+") as file:
# file.write(content)
# def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
# """Instantiates callbacks from config."""
# callbacks: List[Callback] = []
# if not callbacks_cfg:
# log.warning("Callbacks config is empty.")
# return callbacks
# if not isinstance(callbacks_cfg, DictConfig):
# raise TypeError("Callbacks config must be a DictConfig!")
# for _, cb_conf in callbacks_cfg.items():
# if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
# log.info(f"Instantiating callback <{cb_conf._target_}>")
# callbacks.append(hydra.utils.instantiate(cb_conf))
# return callbacks
# def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
# """Instantiates loggers from config."""
# logger: List[Logger] = []
# if not logger_cfg:
# log.warning("Logger config is empty.")
# return logger
# if not isinstance(logger_cfg, DictConfig):
# raise TypeError("Logger config must be a DictConfig!")
# for _, lg_conf in logger_cfg.items():
# if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
# log.info(f"Instantiating logger <{lg_conf._target_}>")
# logger.append(hydra.utils.instantiate(lg_conf))
# return logger
# @rank_zero_only
# def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
# """Controls which config parts are saved by lightning loggers.
# Additionally saves:
# - Number of model parameters
# """
# hparams = {}
# cfg = object_dict["cfg"]
# model = object_dict["model"]
# trainer = object_dict["trainer"]
# if not trainer.logger:
# log.warning("Logger not found! Skipping hyperparameter logging.")
# return
# hparams["model"] = cfg["model"]
# # TODO Accommodation for LazyModule
# # save number of model parameters
# hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
# hparams["model/params/trainable"] = sum(
# p.numel() for p in model.parameters() if p.requires_grad
# )
# hparams["model/params/non_trainable"] = sum(
# p.numel() for p in model.parameters() if not p.requires_grad
# )
# hparams["data"] = cfg["data"]
# hparams["trainer"] = cfg["trainer"]
# hparams["callbacks"] = cfg.get("callbacks")
# hparams["extras"] = cfg.get("extras")
# hparams["job_name"] = cfg.get("job_name")
# hparams["tags"] = cfg.get("tags")
# hparams["ckpt_path"] = cfg.get("ckpt_path")
# hparams["seed"] = cfg.get("seed")
# # send hparams to all loggers
# trainer.logger.log_hyperparams(hparams)
# def close_loggers() -> None:
# """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
# log.info("Closing loggers.")
# if find_spec("wandb"): # if wandb is installed
# import wandb
# if wandb.run:
# log.info("Closing wandb!")
# wandb.finish()