Spaces:
Running
Running
File size: 4,283 Bytes
4f6613a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import random
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable, Optional

import numpy as np
import torch
from omegaconf import DictConfig

from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree

# Module-level logger used by the helper functions in this file.
# NOTE(review): `rank_zero_only=True` presumably limits output to the rank-0
# process in multi-process runs — confirm against `RankedLogger` in `.logger`.
log = RankedLogger(__name__, rank_zero_only=True)
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected in ``cfg.extras``.

    Each utility is switched on by a truthy entry under ``cfg.extras``:

    - ``ignore_warnings``: silence Python warnings
    - ``enforce_tags``: call ``enforce_tags`` on the config
    - ``print_config``: print the config tree via ``print_config_tree``

    If ``cfg`` has no (or an empty) ``extras`` entry, a warning is logged and
    nothing else happens.

    :param cfg: The run configuration.
    """
    # Nothing to do without an `extras` section.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    def is_enabled(option: str) -> bool:
        # An option counts as enabled when its value under `cfg.extras` is truthy.
        return bool(cfg.extras.get(option))

    # Silence python warnings.
    if is_enabled("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Ask for tags on the command line when the config provides none.
    if is_enabled("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    # Pretty-print the config tree with Rich.
    if is_enabled("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to wrap. It is called as
        ``task_func(cfg=cfg)`` and must return a ``(metric_dict, object_dict)`` pair.
    :return: The wrapped function; it keeps the name and docstring of ``task_func``.
    """  # noqa: E501

    # `wraps` copies __name__/__doc__/etc. from the task onto the wrapper so
    # the decorated function stays identifiable in logs and tracebacks.
    @wraps(task_func)
    def wrap(cfg: DictConfig):
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or
            # cause out-of-memory errors so when using hparam search
            # plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.run_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: Mapping from metric name to metric object. The selected
        entry must provide an ``.item()`` method, which is used to extract the value.
    :param metric_name: Name of the metric to look up. If it is ``None`` or
        empty, retrieval is skipped.
    :return: The value returned by ``.item()`` for the requested metric, or
        ``None`` when no metric name was given.
    :raises Exception: If ``metric_name`` is given but not present in ``metric_dict``.
    """
    # No metric requested: skip retrieval and return None.
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is normalised first: a negative seed is replaced by its
    magnitude, and anything larger than ``2**31`` is clamped to ``2**31``.
    When cuDNN is available it is additionally switched to deterministic
    mode with benchmarking turned off.

    :param seed: The requested seed.
    """
    seed_cap = 1 << 31
    seed = min(abs(seed), seed_cap)

    # Seed the CPU-side generators in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Seed the current CUDA device and then every CUDA device.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|