"""Hydra entry point that runs Lightning prediction from a given checkpoint."""
from typing import List, Tuple

import hydra
from omegaconf import DictConfig
from lightning import LightningDataModule, LightningModule, Trainer, Callback

from deepscreen.utils.hydra import checkpoint_rerun_config
from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks

log = get_logger(__name__)


@job_wrapper(extra_utils=True)
def predict(cfg: DictConfig) -> Tuple[list, dict]:
    """Run ``Trainer.predict`` on the configured data using ``cfg.ckpt_path``.

    The function is wrapped by ``job_wrapper(extra_utils=True)``, which applies
    extra utilities before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. It is read for the
            ``data``, ``model``, ``trainer`` and ``ckpt_path`` entries and the
            optional ``callbacks`` entry.

    Returns:
        Tuple[list, dict]: The predictions returned by ``Trainer.predict``
        (called with ``return_predictions=True``) and a dict holding all
        instantiated objects (``cfg``, ``datamodule``, ``model``,
        ``callbacks``, ``trainer``).
    """
    log.info(f"Instantiating data <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    # `callbacks` is read with `.get`, so the config may omit it.
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # Logging is disabled: prediction only produces outputs, no metrics to log.
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=False, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "trainer": trainer,
    }

    log.info("Start predicting.")
    predictions = trainer.predict(
        model=model,
        datamodule=datamodule,
        ckpt_path=cfg.ckpt_path,
        return_predictions=True,
    )

    return predictions, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml")
def main(cfg: DictConfig):
    """CLI entry point: validate the checkpoint path, then run ``predict``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from ``predict.yaml``.

    Returns:
        The predictions produced by ``predict``.

    Raises:
        ValueError: If ``cfg.ckpt_path`` is empty or unset.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not cfg.ckpt_path:
        raise ValueError("Checkpoint path (`ckpt_path`) must be specified for predicting.")

    # NOTE(review): presumably rebuilds the config from the run that produced
    # the checkpoint — confirm against deepscreen.utils.hydra.
    cfg = checkpoint_rerun_config(cfg)
    predictions, _ = predict(cfg)
    return predictions


if __name__ == "__main__":
    main()