# import time
# from pathlib import Path
# from typing import Any, Dict, List
#
# import hydra
# from pytorch_lightning import Callback
# from pytorch_lightning.loggers import Logger
# from pytorch_lightning.utilities import rank_zero_only
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable
from omegaconf import DictConfig
from deepscreen.utils import get_logger, enforce_tags, print_config_tree
log = get_logger(__name__)
def extras(cfg: DictConfig) -> None:
"""Applies optional utilities before a job is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
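
    A minimal example of the matching `extras` section in a Hydra config
    (these are exactly the flags read below; the values are illustrative):
    ```
    extras:
        ignore_warnings: true
        enforce_tags: true
        print_config: true
    ```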
"""
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found! <cfg.extras=null>")
return
# disable python warnings
if cfg.extras.get("ignore_warnings"):
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# prompt user to input tags from command line if none are provided in the config
if cfg.extras.get("enforce_tags"):
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
enforce_tags(cfg, save_to_file=True)
# pretty print config tree using Rich library
if cfg.extras.get("print_config"):
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
print_config_tree(cfg, resolve=True, save_to_file=True)
def job_wrapper(extra_utils: bool) -> Callable:
"""Optional decorator that controls the failure behavior and extra utilities when executing a job function.
This wrapper can be used to:
- make sure loggers are closed even if the job function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
```
@utils.job_wrapper(extra_utils)
def train(cfg: DictConfig) -> Tuple[dict, dict]:
.
return metric_dict, object_dict
```
"""
    def decorator(job_func):
        @wraps(job_func)  # preserve the job function's name/docstring for logging and debugging
        def wrapped_func(cfg: DictConfig):
            # execute the job
            try:
                # apply extra utilities
                if extra_utils:
                    extras(cfg)
                metric_dict, object_dict = job_func(cfg=cfg)
            # things to do if exception occurs
            except Exception:
                # save exception to `.log` file
                log.exception("")
                # some hyperparameter combinations might be invalid or cause out-of-memory errors,
                # so when using hparam search plugins like Optuna, you might want to disable
                # raising the below exception to avoid multirun failure
                # (a lenient variant is sketched after this decorator)
                raise
            # things to always do after either success or exception
            finally:
                # display output dir path in terminal
                log.info(f"Output dir: {cfg.paths.output_dir}")
                # always close wandb run (even if exception occurs so multirun won't fail)
                if find_spec("wandb"):  # check if wandb is installed
                    import wandb
                    if wandb.run:
                        log.info("Closing wandb!")
                        wandb.finish()
            return metric_dict, object_dict
        return wrapped_func
    return decorator
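

# As noted above, when running hyperparameter sweeps (e.g., with Hydra's Optuna sweeper),
# some trials may be invalid or run out of memory, and re-raising would abort the whole
# multirun. A minimal sketch of a lenient variant is below; it is not used anywhere in
# this module, and the `"loss": float("nan")` failure value is an illustrative choice,
# not a project convention:
#
# def lenient_job_wrapper(extra_utils: bool) -> Callable:
#     def decorator(job_func):
#         @wraps(job_func)
#         def wrapped_func(cfg: DictConfig):
#             metric_dict, object_dict = {}, {}
#             try:
#                 if extra_utils:
#                     extras(cfg)
#                 metric_dict, object_dict = job_func(cfg=cfg)
#             except Exception:
#                 # log the full traceback but do not re-raise, so the multirun continues;
#                 # report NaN so the sweeper can rank this trial as failed
#                 log.exception("")
#                 metric_dict = {"loss": float("nan")}
#             finally:
#                 log.info(f"Output dir: {cfg.paths.output_dir}")
#                 if find_spec("wandb"):  # check if wandb is installed
#                     import wandb
#                     if wandb.run:
#                         wandb.finish()
#             return metric_dict, object_dict
#         return wrapped_func
#     return decorator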
# @rank_zero_only
# def save_file(path, content) -> None:
# """Save file in rank zero mode (only on one process in multi-GPU setup)."""
# with open(path, "w+") as file:
# file.write(content)
#
#
# def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
# """Instantiates callbacks from config."""
# callbacks: List[Callback] = []
#
# if not callbacks_cfg:
# log.warning("Callbacks config is empty.")
# return callbacks
#
# if not isinstance(callbacks_cfg, DictConfig):
# raise TypeError("Callbacks config must be a DictConfig!")
#
# for _, cb_conf in callbacks_cfg.items():
# if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
# log.info(f"Instantiating callback <{cb_conf._target_}>")
# callbacks.append(hydra.utils.instantiate(cb_conf))
#
# return callbacks
#
#
# def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
# """Instantiates loggers from config."""
# logger: List[Logger] = []
#
# if not logger_cfg:
# log.warning("Logger config is empty.")
# return logger
#
# if not isinstance(logger_cfg, DictConfig):
# raise TypeError("Logger config must be a DictConfig!")
#
# for _, lg_conf in logger_cfg.items():
# if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
# log.info(f"Instantiating logger <{lg_conf._target_}>")
# logger.append(hydra.utils.instantiate(lg_conf))
#
# return logger
#
#
# @rank_zero_only
# def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
# """Controls which config parts are saved by lightning loggers.
#
# Additionally saves:
# - Number of model parameters
# """
#
# hparams = {}
#
# cfg = object_dict["cfg"]
# model = object_dict["model"]
# trainer = object_dict["trainer"]
#
# if not trainer.logger:
# log.warning("Logger not found! Skipping hyperparameter logging.")
# return
#
# hparams["model"] = cfg["model"]
#
# # TODO Accommodation for LazyModule
# # save number of model parameters
# hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
# hparams["model/params/trainable"] = sum(
# p.numel() for p in model.parameters() if p.requires_grad
# )
# hparams["model/params/non_trainable"] = sum(
# p.numel() for p in model.parameters() if not p.requires_grad
# )
#
# hparams["data"] = cfg["data"]
# hparams["trainer"] = cfg["trainer"]
#
# hparams["callbacks"] = cfg.get("callbacks")
# hparams["extras"] = cfg.get("extras")
#
# hparams["job_name"] = cfg.get("job_name")
# hparams["tags"] = cfg.get("tags")
# hparams["ckpt_path"] = cfg.get("ckpt_path")
# hparams["seed"] = cfg.get("seed")
#
# # send hparams to all loggers
# trainer.logger.log_hyperparams(hparams)
#
#
# def close_loggers() -> None:
# """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
#
# log.info("Closing loggers.")
#
# if find_spec("wandb"): # if wandb is installed
# import wandb
#
# if wandb.run:
# log.info("Closing wandb!")
# wandb.finish()