File size: 2,398 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

from pathlib import Path
import torch 
from torchvision import utils 
import math 
from medical_diffusion.models.pipelines import DiffusionPipeline

def rgb2gray(img):
    """Collapse the first three channels of a batch into one weighted channel.

    Args:
        img: Batch indexed as [B, C, H, W]; channels 0, 1 and 2 are read.

    Returns:
        Batch of shape [B, 1, H, W] holding 0.3*ch0 + 0.59*ch1 + 0.11*ch2.
    """
    red, green, blue = img[:, 0], img[:, 1], img[:, 2]
    weighted = (0.3 * red) + (0.59 * green) + (0.11 * blue)
    # An equal-weight alternative would use 0.33 for each of the three channels.
    # Indexing with None re-inserts the channel axis dropped by the selection above.
    return weighted[:, None]

def normalize(img):
    """Min-max rescale every item along the first axis of ``img`` to [0, 1].

    Each item is scaled independently using its own minimum and maximum.
    An item whose values are all equal has a zero value range; instead of
    dividing by zero (which would yield NaN everywhere) it is mapped to zeros.

    Args:
        img: Tensor whose first axis enumerates the items to normalize.

    Returns:
        Tensor of the same shape with each item rescaled to [0, 1].
    """
    # Alternative (disabled): clamp each item to its 0.001 / 0.999 quantiles
    # before rescaling to suppress outliers.
    def _rescale(item):
        lowest = item.min()
        span = item.max() - lowest
        # Replace a zero span by one so a constant item becomes 0 instead of NaN.
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        return (item - lowest) / safe_span

    return torch.stack([_rescale(b) for b in img])

if __name__ == "__main__":
    # Script entry point: load a diffusion pipeline checkpoint, draw a batch of
    # samples for each conditioning value (0, 1 and unconditional), save each
    # batch as an image grid and finally save a per-sample absolute-difference
    # map between the cond=1 and cond=0 batches.
    path_out = Path.cwd()/'results/CheXpert/samples'
    path_out.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(0)
    # Fall back to the CPU when CUDA is not available instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ------------ Load Model ------------
    # pipeline = DiffusionPipeline.load_best_checkpoint(path_run_dir)
    pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_12_12_171357_chest_diffusion/last.ckpt')
    pipeline.to(device)

    # --------- Generate Samples  -------------------
    steps = 150
    use_ddim = True
    images = {}  # maps the conditioning value (0, 1 or None) to its sampled batch
    n_samples = 16

    for cond in [0, 1, None]:
        # Re-seed before every run so each condition starts from the same RNG state.
        torch.manual_seed(0)

        # --------- Conditioning ---------
        # One label per sample; None requests sampling without a condition tensor.
        condition = torch.tensor([cond]*n_samples, device=device) if cond is not None else None
        # un_cond = torch.tensor([1-cond]*n_samples, device=device)
        un_cond = None

        # ----------- Run --------
        results = pipeline.sample(n_samples, (8, 32, 32), guidance_scale=8, condition=condition, un_cond=un_cond, steps=steps, use_ddim=use_ddim)
        # results = pipeline.sample(n_samples, (4, 64, 64), guidance_scale=1, condition=condition, un_cond=un_cond, steps=steps, use_ddim=use_ddim )

        # --------- Save result ---------------
        results = (results+1)/2  # Transform from [-1, 1] to [0, 1]
        results = results.clamp(0, 1)
        # Roughly square grid: sqrt(batch size) images per row. For 2D images: [B, C, H, W]
        utils.save_image(results, path_out/f'test_{cond}.png', nrow=int(math.sqrt(results.shape[0])), normalize=True, scale_each=True)
        images[cond] = results

    # Absolute difference of the per-sample normalized grayscale versions of the
    # cond=1 and cond=0 batches. [0,1] -> [0, 1]
    diff = torch.abs(normalize(rgb2gray(images[1]))-normalize(rgb2gray(images[0])))
    # diff = torch.abs(images[1]-images[0])
    # Size the grid from the tensor actually being saved rather than from the
    # loop variable left over from the last iteration.
    utils.save_image(diff, path_out/'diff.png', nrow=int(math.sqrt(diff.shape[0])), normalize=True, scale_each=True)