from pathlib import Path
import math

import torch
import torch.nn.functional as F
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import seaborn as sns

from medical_diffusion.data.datamodules import SimpleDataModule
from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN

# ------------------ Settings -------------------
path_out = Path.cwd()/'results/test/latent_embedder'
path_out.mkdir(parents=True, exist_ok=True)
device = torch.device('cuda')
torch.manual_seed(0)

# ------------------ Dataset -------------------
# ds = AIROGSDataset(  # 256x256
#     crawler_ext='jpg',
#     augment_horizontal_flip=True,
#     augment_vertical_flip=True,
#     # path_root='/home/gustav/Documents/datasets/AIROGS/dataset',
#     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
# )

# ds = MSIvsMSS_2_Dataset(  # 512x512
#     # image_resize=256,
#     crawler_ext='jpg',
#     augment_horizontal_flip=False,
#     augment_vertical_flip=False,
#     # path_root='/home/gustav/Documents/datasets/Kather_2/train'
#     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/'
# )

ds = CheXpert_2_Dataset(  # 256x256
    augment_horizontal_flip=False,
    augment_vertical_flip=False,
    path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
)

dm = SimpleDataModule(
    ds_train=ds,
    batch_size=4,
    num_workers=0,
)

# ------------------ Load Model -------------------
model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')

# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
# model = pipe.vae

model = model.to(device)

# ------------- Reset Seed ------------
torch.manual_seed(0)

# ------------ Prepare Data ----------------
data_iter = iter(dm.train_dataloader())
for k in range(1):
    batch = next(data_iter)

x = batch['source']
x = x.to(device)  # .to(torch.float16)

# ------------- Run Model ----------------
with torch.no_grad():
    # ------------- Encode ----------
    z = model.encode(x)
    # z = z.latent_dist.sample()  # Only for stable-diffusion

    # Plot the distribution of latent values
    sns.histplot(z.flatten().detach().cpu().numpy())
    plt.savefig('test.png')

    # ------------- Decode -----------
    x_pred = model.decode(z)
    # x_pred = x_pred.sample  # Only for stable-diffusion
    x_pred = x_pred.clamp(-1, 1)

images = x_pred[0]  # torch.cat([x, x_pred])
save_image(images, path_out/'latent_embedder_vaegan.png', nrow=x.shape[0], normalize=True, scale_each=True)
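
# --------- Optional: quantify reconstruction quality (sketch) ---------
# A minimal, optional check of how faithfully the latent embedder reconstructs the
# input batch, not part of the original script. It assumes x and x_pred share the
# same shape and value range ([-1, 1]); the output file name below is illustrative.
mse = F.mse_loss(x_pred, x).item()
mae = F.l1_loss(x_pred, x).item()
print(f'Reconstruction MSE: {mse:.5f} | MAE: {mae:.5f}')

# Side-by-side grid of inputs (top row) and reconstructions (bottom row),
# mirroring the commented-out torch.cat([x, x_pred]) above.
save_image(torch.cat([x, x_pred]), path_out/'latent_embedder_comparison.png', nrow=x.shape[0], normalize=True, scale_each=True)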