import logging

from lightning.pytorch.utilities import rank_zero_only
from lightning.pytorch.utilities.model_summary import ModelSummary


def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger.

    Each level method (``debug``, ``info``, ``warning``, ``error``,
    ``exception``, ``fatal``, ``critical``) of the returned logger is wrapped
    with ``rank_zero_only`` so that, in a multi-process run, only the
    rank-zero process emits the message.

    Args:
        name: Logger name, forwarded to ``logging.getLogger``.

    Returns:
        The ``logging.Logger`` registered under ``name`` (the same object
        ``logging.getLogger(name)`` returns), with its level methods patched.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        # logging.getLogger hands back the same object for the same name, so
        # without this check every call to get_logger would wrap the already
        # wrapped method once more. The level methods are normally looked up
        # on the Logger class; after patching they live in the instance dict,
        # which is how an already-patched level is recognised.
        if level in vars(logger):
            continue
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally, saves:
        - Number of model parameters

    The function is decorated with ``rank_zero_only``, so the hyperparameters
    are sent from the rank-zero process only.

    Args:
        object_dict: Mapping with the keys ``"cfg"``, ``"model"`` and
            ``"trainer"``. ``cfg`` must support both ``cfg[key]`` and
            ``cfg.get(key)`` (presumably a Hydra/OmegaConf ``DictConfig`` --
            confirm against callers); ``model`` is passed to
            ``ModelSummary``; ``trainer`` provides ``logger`` and ``loggers``.
    """
    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging.")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    model_summary = ModelSummary(model)
    total_params = model_summary.total_parameters
    trainable_params = model_summary.trainable_parameters
    hparams["model/params/total"] = total_params
    hparams["model/params/trainable"] = trainable_params
    hparams["model/params/non_trainable"] = total_params - trainable_params

    # these sections are indexed directly, i.e. treated as required
    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    # optional entries: fetched with .get so an absent key is logged as None
    for key in ("callbacks", "extras", "job_name", "tags", "ckpt_path", "seed"):
        hparams[key] = cfg.get(key)

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)