|
from typing import List
|
|
|
|
import hydra
|
|
from omegaconf import DictConfig
|
|
from pytorch_lightning import Callback
|
|
from pytorch_lightning.loggers import Logger
|
|
|
|
from .logger import RankedLogger
|
|
|
|
# Module-level logger used by the instantiation helpers below.
# NOTE(review): `rank_zero_only=True` presumably restricts output to the
# rank-zero process in distributed runs — confirm against RankedLogger in .logger.
log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
|
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build the Lightning callbacks described by a Hydra config section.

    Each value in ``callbacks_cfg`` that is itself a ``DictConfig`` carrying a
    ``_target_`` key is handed to ``hydra.utils.instantiate``; any other value
    is skipped without a log message.

    :param callbacks_cfg: Mapping of callback name -> callback config. A falsy
        value (e.g. ``None`` or an empty config) is accepted and yields ``[]``
        after logging a warning.
    :return: The instantiated callbacks, in the config's iteration order.
    :raises TypeError: If ``callbacks_cfg`` is truthy but not a ``DictConfig``.
    """
    # Nothing configured: warn and return an empty list rather than failing.
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    built: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        # Only nested configs that name a `_target_` can be instantiated.
        # Anything else (presumably e.g. entries nulled out by an override —
        # confirm with the config files) is ignored.
        is_instantiable = isinstance(entry, DictConfig) and "_target_" in entry
        if not is_instantiable:
            continue

        log.info(f"Instantiating callback <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built
|
|
|
|
|
|
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build the Lightning loggers described by a Hydra config section.

    Each value in ``logger_cfg`` that is itself a ``DictConfig`` carrying a
    ``_target_`` key is handed to ``hydra.utils.instantiate``; any other value
    is skipped without a log message.

    :param logger_cfg: Mapping of logger name -> logger config. A falsy value
        (e.g. ``None`` or an empty config) is accepted and yields ``[]`` after
        logging a warning.
    :return: The instantiated loggers, in the config's iteration order.
    :raises TypeError: If ``logger_cfg`` is truthy but not a ``DictConfig``.
    """
    # Nothing configured: warn and return an empty list rather than failing.
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    built: List[Logger] = []
    for _name, entry in logger_cfg.items():
        # Only nested configs that name a `_target_` can be instantiated.
        # Anything else (presumably e.g. entries nulled out by an override —
        # confirm with the config files) is ignored.
        is_instantiable = isinstance(entry, DictConfig) and "_target_" in entry
        if not is_instantiable:
            continue

        log.info(f"Instantiating logger <{entry._target_}>")
        built.append(hydra.utils.instantiate(entry))

    return built
|
|
|