# Spaces: Runtime error
from pathlib import Path | |
import math | |
import torch | |
import torch.nn.functional as F | |
from torchvision.utils import save_image | |
from medical_diffusion.data.datamodules import SimpleDataModule | |
from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset | |
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# ------------------ Setup -------------------
# Every artifact produced by this script is written under this directory.
path_out = Path.cwd()/'results/test/latent_embedder'
path_out.mkdir(parents=True, exist_ok=True)

# Fall back to the CPU when no CUDA device is present. A hard-coded
# torch.device('cuda') makes the first `.to(device)` call raise on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)
# ------------------ Dataset -------------------
# Alternative datasets used during development. To switch, uncomment one of
# them and comment out the CheXpert assignment further below.
#
# ds = AIROGSDataset(  # 256x256
#     crawler_ext='jpg',
#     augment_horizontal_flip=True,
#     augment_vertical_flip=True,
#     # path_root='/home/gustav/Documents/datasets/AIROGS/dataset',
#     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
# )
#
# ds = MSIvsMSS_2_Dataset(  # 512x512
#     # image_resize=256,
#     crawler_ext='jpg',
#     augment_horizontal_flip=False,
#     augment_vertical_flip=False,
#     # path_root='/home/gustav/Documents/datasets/Kather_2/train'
#     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/'
# )

# Active dataset: CheXpert (256x256 per the author's note), with flip
# augmentation disabled on both axes.
chexpert_root = '/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
ds = CheXpert_2_Dataset(  # 256x256
    augment_horizontal_flip=False,
    augment_vertical_flip=False,
    path_root=chexpert_root,
)

# Wrap the dataset as the training split; batches of 4 with num_workers=0
# (presumably forwarded to the underlying DataLoader, i.e. main-process loading).
dm = SimpleDataModule(
    ds_train=ds,
    batch_size=4,
    num_workers=0,
)
# ------------------ Load Model -------------------
# Checkpoint from the `chest_vaegan` run, restored with the plain `VAE` class.
# map_location lets a checkpoint that was saved on a GPU be restored on
# whatever `device` this host provides (including CPU-only machines).
model = VAE.load_from_checkpoint(
    'runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt',
    map_location=device,
)

# Alternative: the Stable Diffusion VAE via diffusers (also requires the
# `.latent_dist.sample()` / `.sample` lines in the run section).
# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
# model = pipe.vae

model = model.to(device)
# Inference only: put any dropout / batch-norm layers into evaluation mode so
# the reconstructions are deterministic and use running statistics.
model.eval()
# ------------- Reset Seed ------------
# Re-seed immediately before sampling, presumably so the drawn batch does not
# depend on how much randomness the model-loading step consumed.
torch.manual_seed(0)

# ------------ Prepare Data ----------------
# Draw exactly one batch from the training loader and move its input images
# (stored under the 'source' key) onto `device`.
data_iter = iter(dm.train_dataloader())
batch = next(data_iter)
x = batch['source'].to(device)  # .to(torch.float16)
# ------------- Run Model ----------------
with torch.no_grad():
    # ------------- Encode ----------
    z = model.encode(x)
    # z = z.latent_dist.sample()  # Only for stable-diffusion

    # Histogram of every latent value, to inspect the latent's value range.
    # Saved next to the other outputs (not into the working directory), and the
    # figure is closed afterwards so it does not leak into later plotting.
    sns.histplot(z.flatten().cpu().numpy())
    plt.savefig(path_out/'latent_histogram.png')
    plt.close()

    # ------------- Decode -----------
    x_pred = model.decode(z)
    # x_pred = x_pred.sample  # Only for stable-diffusion
    x_pred = x_pred.clamp(-1, 1)

# Comparison grid: nrow = batch size, so the first row holds the inputs and the
# second row their reconstructions. Assumes the decoder output has the same
# shape as the input (required by torch.cat). Each image is min-max rescaled
# individually for display.
images = torch.cat([x, x_pred])
save_image(images, path_out/'latent_embedder_vaegan.png', nrow=x.shape[0], normalize=True, scale_each=True)