libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
2.34 kB
from typing import List, Tuple
import hydra
from omegaconf import DictConfig
from lightning import LightningDataModule, LightningModule, Trainer, Callback
from deepscreen.utils.hydra import checkpoint_rerun_config
from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks
# Module-level logger for this script, obtained from the project's get_logger helper.
log = get_logger(__name__)
# NOTE: disabled helper, kept commented out for reference only.
# def fix_dict_config(cfg: DictConfig):
#     """Fix all vars in the cfg config.
#     This is an in-place operation."""
#     keys = list(cfg.keys())
#     for k in keys:
#         if type(cfg[k]) is DictConfig:
#             fix_dict_config(cfg[k])
#         else:
#             setattr(cfg, k, getattr(cfg, k))
@job_wrapper(extra_utils=True)
def predict(cfg: DictConfig) -> Tuple[list, dict]:
    """Run prediction from a checkpoint using the configured data, model and trainer.

    The function is decorated with ``@job_wrapper(extra_utils=True)``, which
    applies extra utilities before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. The entries read
            here are ``data``, ``model``, ``trainer``, ``ckpt_path`` and the
            optional ``callbacks``.

    Returns:
        Tuple[list, dict]: The value returned by ``trainer.predict`` followed
        by a dict of every object instantiated here, keyed ``cfg``,
        ``datamodule``, ``model``, ``callbacks`` and ``trainer``.
    """
    # Build each component from its config node, logging the target first.
    log.info(f"Instantiating data <{cfg.data._target_}>")
    data_module: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    net: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    callback_list: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # ``logger`` and ``callbacks`` are supplied as keyword arguments on top of
    # the trainer config node.
    runner: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        logger=False,
        callbacks=callback_list,
    )

    # Everything created above is handed back to the caller alongside the predictions.
    object_dict = dict(
        cfg=cfg,
        datamodule=data_module,
        model=net,
        callbacks=callback_list,
        trainer=runner,
    )

    log.info("Start predicting.")
    predictions = runner.predict(
        model=net,
        datamodule=data_module,
        ckpt_path=cfg.ckpt_path,
        return_predictions=True,
    )
    return predictions, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="predict.yaml")
def main(cfg: DictConfig):
    """Hydra entry point: validate the config, then run prediction.

    Args:
        cfg (DictConfig): Configuration composed by Hydra from
            ``../configs/predict.yaml``.

    Returns:
        The predictions returned by ``predict``; the accompanying object dict
        is discarded.

    Raises:
        ValueError: If ``cfg.ckpt_path`` is empty or unset.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let prediction continue without a checkpoint.
    if not cfg.ckpt_path:
        raise ValueError("Checkpoint path (`ckpt_path`) must be specified for predicting.")

    # NOTE(review): presumably merges the configuration stored with the
    # checkpoint's original run into ``cfg`` -- confirm against
    # deepscreen.utils.hydra.checkpoint_rerun_config.
    cfg = checkpoint_rerun_config(cfg)

    predictions, _ = predict(cfg)
    return predictions


if __name__ == "__main__":
    main()