File size: 2,710 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from pathlib import Path
import math 

import torch 
import torch.nn.functional as F
from torchvision.utils import save_image

from medical_diffusion.data.datamodules import SimpleDataModule
from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
import matplotlib.pyplot as plt 
import seaborn as sns 

# ------------------ Setup -------------------
# Every artifact this script produces is written below path_out.
path_out = Path.cwd()/'results/test/latent_embedder'
path_out.mkdir(parents=True, exist_ok=True)
# Prefer the GPU but fall back to the CPU, so the script still runs on machines
# without CUDA (a hard-coded 'cuda' device fails on the first .to(device) there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed torch's global RNG for reproducible runs.
torch.manual_seed(0)

# ------------------ Dataset -------------------
# Alternative datasets; swap one in for CheXpert below when needed.
#
# ds = AIROGSDataset(  # 256x256
#     crawler_ext='jpg',
#     augment_horizontal_flip=True,
#     augment_vertical_flip=True,
#     # path_root='/home/gustav/Documents/datasets/AIROGS/dataset',
#     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
# )
#
# ds = MSIvsMSS_2_Dataset(  # 512x512
#     # image_resize=256,
#     crawler_ext='jpg',
#     augment_horizontal_flip=False,
#     augment_vertical_flip=False,
#     # path_root='/home/gustav/Documents/datasets/Kather_2/train'
#     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/'
# )

# CheXpert (256x256) with both flip-augmentation flags switched off.
chexpert_root = '/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
ds = CheXpert_2_Dataset(
    path_root=chexpert_root,
    augment_horizontal_flip=False,
    augment_vertical_flip=False,
)

# Batches of 4; num_workers=0 presumably keeps loading in the main process
# (confirm against SimpleDataModule).
dm = SimpleDataModule(ds_train=ds, batch_size=4, num_workers=0)


# ------------------ Load Model -------------------
path_ckpt = 'runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt'
# map_location lets a checkpoint saved on one device (e.g. GPU) be restored
# onto whatever `device` is in use. This assumes VAE is a Lightning module,
# whose load_from_checkpoint accepts map_location.
model = VAE.load_from_checkpoint(path_ckpt, map_location=device)

# Alternative: the Stable Diffusion VAE from diffusers.
# from diffusers import StableDiffusionPipeline
# with open('auth_token.txt', 'r') as file:
#     auth_token = file.read()
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32,  use_auth_token=auth_token)
# model = pipe.vae

model = model.to(device)
# Switch to inference mode so that any dropout / batch-norm layers behave
# deterministically. torch.no_grad() below only disables autograd; it does
# not change module training mode.
model.eval()

# ------------- Reset Seed ------------
# Re-seed after model construction/loading, so that any RNG state consumed
# there does not change which (possibly shuffled/augmented) batch is drawn.
torch.manual_seed(0)

# ------------ Prepare Data ----------------
# Draw `num_batches` batches and keep the last one; increase to inspect a
# later batch.
num_batches = 1
data_iter = iter(dm.train_dataloader())
for _ in range(num_batches):
    batch = next(data_iter)
x = batch['source'].to(device)  # .to(torch.float16)

# ------------- Run Model ----------------
with torch.no_grad():
    # ------------- Encode ----------
    z = model.encode(x)
    # z = z.latent_dist.sample() # Only for stable-diffusion

    # Histogram of all latent values, saved next to the other outputs.
    # A dedicated figure is created and closed so repeated runs in one session
    # do not overlay onto the implicit global figure.
    fig, ax = plt.subplots()
    sns.histplot(z.flatten().cpu().numpy(), ax=ax)  # no .detach() needed under no_grad
    fig.savefig(path_out/'latent_histogram.png')
    plt.close(fig)

    # ------------- Decode -----------
    x_pred = model.decode(z)
    # x_pred = x_pred.sample # Only for stable-diffusion
    x_pred = x_pred.clamp(-1, 1)

# Only the first reconstruction is saved. Use torch.cat([x, x_pred]) instead to
# save inputs and reconstructions together; nrow=x.shape[0] then puts the
# inputs in the first grid row and the reconstructions in the second.
images = x_pred[0]
save_image(images, path_out/'latent_embedder_vaegan.png', nrow=x.shape[0], normalize=True, scale_each=True)