File size: 3,348 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121


from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler
from medical_diffusion.data.datasets import SimpleDataset2D, AIROGSDataset, CheXpert_Dataset, MSIvsMSS_2_Dataset
from medical_diffusion.models.embedders.latent_embedders import VAE, VAEGAN
import torch 
from pathlib import Path
import matplotlib.pyplot as plt 
import seaborn as sns 
from math import ceil


from torchvision.utils import save_image

# Dataset to inspect; only its first sample is used further down.
# Alternatives used previously (swap one in for the MSIvsMSS_2_Dataset call below):
#   SimpleDataset2D(crawler_ext='jpg', image_resize=(352, 528), image_crop=(192, 288),
#                   path_root='/home/gustav/Documents/datasets/AIROGS/dataset')
#   AIROGSDataset(crawler_ext='jpg', image_resize=(256, 256), image_crop=(256, 256),
#                 path_root='/home/gustav/Documents/datasets/AIROGS/dataset')
#                 # other copy: /mnt/hdd/datasets/eye/AIROGS/data/
#   CheXpert_Dataset(crawler_ext='jpg', augment_horizontal_flip=False, augment_vertical_flip=False,
#                    path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed/valid')
dataset_kwargs = dict(
    # local copy: '/home/gustav/Documents/datasets/Kather_2/train'
    path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/',
    crawler_ext='jpg',
    image_resize=None,
    image_crop=None,
    # flip-augmentation flags switched off
    augment_horizontal_flip=False,
    augment_vertical_flip=False,
)
ds = MSIvsMSS_2_Dataset(**dataset_kwargs)

# Target device; currently unused because the scheduler.to(device) call is commented out.
device = torch.device('cuda')

# Forward-noising schedule whose effect on the sample is visualised.
scheduler = GaussianNoiseScheduler(
    schedule_strategy='scaled_linear',
    beta_start=1e-4,
    timesteps=1000,
)
# scheduler.to(device)

# Output folder for the generated figures (created together with missing parents).
path_out = Path.cwd().joinpath('results', 'test', 'scheduler')
path_out.mkdir(exist_ok=True, parents=True)


# --- Encode the first sample and build its forward-noising trajectory -------
# Seed before any random draw so the noise below is reproducible.
torch.manual_seed(0)
# [None] prepends a batch axis -> [B=1, C, H, W] (assumes 'source' is [C, H, W] -- TODO confirm)
x_0 = ds[0]['source'][None]

# Histograms are taken of the autoencoder's output, so encode the image first.
# eval() disables train-mode behaviour (e.g. dropout / batch-norm updates) for this
# inference-only pass, as recommended after Lightning's load_from_checkpoint.
path_ckpt = 'runs/2022_11_25_232957_patho_vaegan/last_vae.ckpt'
embedder = VAE.load_from_checkpoint(path_ckpt).eval()
with torch.no_grad():
    x_0 = embedder.encode(x_0)

# Without an embedder (pixel space), rescale to [-1, 1] instead:
#   x_0 = (x_0 - x_0.min()) / (x_0.max() - x_0.min()) * 2 - 1

# A single noise draw shared by all timesteps, so the panels differ only in t.
noise = torch.randn_like(x_0)

step = 100
# Every `step`-th timestep from 0 up to scheduler.T (inclusive when T % step == 0);
# never overshoots T, and its length drives the figure layout below.
timesteps = list(range(0, scheduler.T + 1, step))

# Compute each noised sample exactly once; reused for the histograms below.
x_t_list = []
for t in timesteps:
    t_batch = torch.tensor([t])
    x_t = scheduler.estimate_x_t(x_0=x_0, t=t_batch, x_T=noise)  # [B, C, H, W]
    print(t_batch, x_t.mean(), x_t.std())
    x_t_list.append(x_t)

# All noised samples stacked along the batch axis (input for the optional image dump).
x_ts = torch.cat(x_t_list)
# save_image(x_ts, path_out/'scheduler_nosing.png', normalize=True, scale_each=True)

# --- Histogram grid: x_0, then x_t for every timestep, then the pure noise ---
binrange = (-2.5, 2.5)
bins = 50
ncols = 8

panels = [x_0, *x_t_list, noise]
nelem = len(panels)
nrows = ceil(nelem / ncols)
# squeeze=False keeps the axes array 2-D for any grid size, so flatten() always works.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), squeeze=False)
axes = axes.flatten()

for axis in axes:
    # Only the value distribution matters: drop the y axis and the surrounding frame.
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_visible(False)
    axis.get_yaxis().set_visible(False)

for axis, values in zip(axes, panels):
    # Hand seaborn a flat NumPy array rather than a torch tensor.
    sns.histplot(x=values.detach().cpu().numpy().ravel(), bins=bins, binrange=binrange, ax=axis)

# Hide grid cells left over when nelem is not a multiple of ncols.
for axis in axes[nelem:]:
    axis.set_visible(False)

fig.tight_layout()
fig.savefig(path_out/'scheduler_nosing_histo.png')
plt.close(fig)