Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from typing import List, Tuple | |
import hydra | |
from omegaconf import DictConfig | |
from lightning import LightningDataModule, LightningModule, Trainer, Callback | |
from deepscreen.utils.hydra import checkpoint_rerun_config | |
from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks | |
# Module-level logger; ``predict`` below reports its progress through it.
log = get_logger(__name__)
# NOTE(review): ``job_wrapper`` is imported by this module but is not applied to
# this function here; confirm whether a decorator line was intended/dropped.
def predict(cfg: DictConfig) -> Tuple[list, dict]:
    """Run ``Trainer.predict`` using the checkpoint given by ``cfg.ckpt_path``.

    Instantiates the datamodule, model, callbacks and trainer described by the
    Hydra config, then runs prediction with the weights loaded from
    ``cfg.ckpt_path``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Must provide
            ``data``, ``model`` and ``trainer`` nodes (each with a
            ``_target_``) and a ``ckpt_path``. ``callbacks`` is read with
            ``cfg.get`` and may be absent.

    Returns:
        Tuple[list, dict]: The predictions returned by ``trainer.predict``
        (called with ``return_predictions=True``), and a dict of every
        instantiated object keyed by ``"cfg"``, ``"datamodule"``, ``"model"``,
        ``"callbacks"`` and ``"trainer"``.
    """
    log.info(f"Instantiating data <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # logger=False: no experiment logger is attached for a prediction run.
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=False, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "trainer": trainer,
    }

    log.info("Start predicting.")
    # ckpt_path makes Lightning restore the model weights from that checkpoint
    # before running the predict loop.
    predictions = trainer.predict(
        model=model,
        datamodule=datamodule,
        ckpt_path=cfg.ckpt_path,
        return_predictions=True,
    )

    return predictions, object_dict
def main(cfg: DictConfig):
    """Validate the config, pass it through ``checkpoint_rerun_config``, and predict.

    Args:
        cfg (DictConfig): Configuration composed by Hydra; must set ``ckpt_path``.

    Returns:
        The predictions produced by ``predict`` (the object dict is discarded).

    Raises:
        ValueError: If ``ckpt_path`` is absent or empty in ``cfg``.
    """
    # Explicit check rather than ``assert`` so the validation is not stripped
    # under ``python -O``; ``cfg.get`` also yields this clear message when the
    # key is absent instead of an attribute-access error.
    if not cfg.get("ckpt_path"):
        raise ValueError("Checkpoint path (`ckpt_path`) must be specified for predicting.")

    # Project helper; presumably re-derives the run config associated with the
    # checkpoint — TODO confirm its exact semantics.
    cfg = checkpoint_rerun_config(cfg)
    predictions, _ = predict(cfg)
    return predictions


if __name__ == "__main__":
    # NOTE(review): ``main`` requires a ``cfg`` argument, so this bare call only
    # works if ``main`` is wrapped by a decorator that supplies the composed
    # config (e.g. ``@hydra.main(...)``). None is visible here — confirm whether
    # it was dropped.
    main()