# Source: medfusion-app/scripts/helpers/dump_discrimnator.py
# Author: mueller-franzes, commit f85e212 ("init"), 890 bytes.
from pathlib import Path
import torch
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# Extract the `vqvae` sub-model from a trained VAEGAN run and store it as a
# standalone Lightning checkpoint (`last_vae.ckpt`) next to the original one.
# Then reload that file with `VAE.load_from_checkpoint` to check that it can
# be read back as a plain VAE.
# NOTE(review): despite this script's file name ("dump_discrimnator"), nothing
# below touches a discriminator; only `model.vqvae` is written out.

# Run directory that holds the VAEGAN checkpoint. The extracted VAE checkpoint
# is written into the same directory. Hard-coded to one specific run; edit
# this path before reusing the script.
path_root = Path('runs/2022_12_01_210017_patho_vaegan')

# Load the full VAEGAN model from its last checkpoint.
model = VAEGAN.load_from_checkpoint(path_root/'last.ckpt')
# Alternative (unused): raw load of the checkpoint file with torch.
# model = torch.load(path_root/'last.ckpt')

# Save only the `vqvae` part of the model.
# Pickling the sub-module directly was tried and is recorded as not working.
# The likely reason is that `VAE.load_from_checkpoint` expects a Lightning
# checkpoint rather than a pickled module object (TODO confirm).
# torch.save(model.vqvae, path_root/'last_vae.ckpt') # Not working

# ------ Ugly workaround ----------
# Build a throw-away Trainer (fit() is never called here) and attach
# `model.vqvae` to it by hand. `Trainer.save_checkpoint` then writes a
# Lightning-format checkpoint that contains only the sub-model.
# NOTE(review): assigning `strategy._lightning_module` writes a private
# PyTorch Lightning attribute, so this may break on other Lightning versions.
# NOTE(review): `trainer.model` is presumably set because `save_checkpoint`
# refuses to run without an attached model (confirm for the pinned version).
# NOTE(review): the ModelCheckpoint callback looks unnecessary because fit()
# is never called; confirm before removing it.
checkpointing = ModelCheckpoint()
trainer = Trainer(callbacks=[checkpointing])
trainer.strategy._lightning_module = model.vqvae
trainer.model = model.vqvae
trainer.save_checkpoint(path_root/'last_vae.ckpt')
# -----------------

# Round-trip check: reload the dumped file as a standalone VAE. This rebinds
# `model`, and the visible code does not use the reloaded object further.
model = VAE.load_from_checkpoint(path_root/'last_vae.ckpt')
# Alternative (unused): raw torch load, e.g. to pass a state dict to
# `load_state_dict` manually.
# model = torch.load(path_root/'last_vae.ckpt') # load_state_dict