|
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable, Optional

from omegaconf import DictConfig

from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
|
|
|
|
# Module-level logger shared by every helper in this file.
# NOTE(review): `rank_zero_only=True` presumably restricts output to the rank-0 process
# in distributed runs — confirm against RankedLogger in .logger.
log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
|
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected under ``cfg.extras``.

    Each utility is switched on by a truthy flag in ``cfg.extras`` and runs in this order:

    - ``ignore_warnings``: silence all Python warnings (``warnings.filterwarnings("ignore")``)
    - ``enforce_tags``: call ``enforce_tags(cfg, save_to_file=True)`` (tags from the command line)
    - ``print_config``: print the resolved config tree with Rich (``save_to_file=True``)

    If ``cfg.extras`` is missing or empty, a warning is logged and nothing else runs.

    Args:
        cfg: The composed OmegaConf config.
    """
    if cfg.get("extras"):
        # (flag under cfg.extras, log message, action); the tuple order is the execution order.
        utilities = (
            (
                "ignore_warnings",
                "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
                lambda: warnings.filterwarnings("ignore"),
            ),
            (
                "enforce_tags",
                "Enforcing tags! <cfg.extras.enforce_tags=True>",
                lambda: enforce_tags(cfg, save_to_file=True),
            ),
            (
                "print_config",
                "Printing config tree with Rich! <cfg.extras.print_config=True>",
                lambda: print_config_tree(cfg, resolve=True, save_to_file=True),
            ),
        )
        for flag, message, apply_utility in utilities:
            if cfg.extras.get(flag):
                log.info(message)
                apply_utility()
    else:
        log.warning("Extras config not found! <cfg.extras=null>")
|
|
|
|
|
|
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    As currently implemented, the wrapper:
    - logs any exception raised by ``task_func`` with its traceback (``log.exception``) and re-raises it;
    - always (success or failure) logs ``cfg.paths.run_dir`` and, if ``wandb`` is installed
      and a wandb run is active, calls ``wandb.finish()``.

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:

        ...

        return metric_dict, object_dict
    ```

    Args:
        task_func: The task function to wrap. It is called as ``task_func(cfg=cfg)`` and its
            result is unpacked into a ``(metric_dict, object_dict)`` pair.

    Returns:
        The wrapped function. It preserves ``task_func``'s name and docstring, takes a single
        ``cfg`` argument and returns ``(metric_dict, object_dict)``.
    """

    @wraps(task_func)  # keep task_func's __name__/__doc__/__wrapped__ on the wrapper
    def wrap(cfg: DictConfig):
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception:
            # Record the full traceback in the run's log before the exception propagates.
            log.exception("")
            # Bare `raise` re-raises with the original traceback untouched.
            raise
        finally:
            # Runs on both success and failure so external loggers are always closed
            # (an open wandb run would otherwise leak into the next multirun job).
            log.info(f"Output dir: {cfg.paths.run_dir}")

            # Only touch wandb when it is installed; the import is deferred for that reason.
            if find_spec("wandb"):
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
|
|
|
|
|
|
def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieve the value of a metric logged in a LightningModule.

    Args:
        metric_dict: Mapping from metric name to logged value. Values that provide ``.item()``
            (e.g. tensors or NumPy scalars) are converted with it; other values are
            converted with ``float()``.
        metric_name: Name of the metric to retrieve. If falsy (``None`` or ``""``),
            retrieval is skipped.

    Returns:
        The metric value as a Python number, or ``None`` when ``metric_name`` is falsy.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    raw_value = metric_dict[metric_name]
    # Tensor-like values expose `.item()`; plain Python numbers do not, so fall back to
    # float() instead of failing with AttributeError.
    metric_value = raw_value.item() if hasattr(raw_value, "item") else float(raw_value)
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
|
|
|