File size: 2,791 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

from pathlib import Path
import torch 
from torchvision import utils 
import math 
from medical_diffusion.models.pipelines import DiffusionPipeline
import numpy as np 
from PIL import Image
import time

def chunks(lst, n):
    """Split *lst* into consecutive slices of length *n*.

    The final slice is shorter when ``len(lst)`` is not a multiple of *n*;
    an empty *lst* yields nothing.
    """
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)

# ------------ Load Model ------------
# Runs at import time: `pipeline` and `device` are module globals that the
# __main__ sampling loop below reads.
# Commented-out alternative kept from the original:
#   pipeline = DiffusionPipeline.load_best_checkpoint(path_run_dir)
pipeline = DiffusionPipeline.load_from_checkpoint(
    'runs/2022_12_12_171357_chest_diffusion/last.ckpt'
)
device = torch.device('cuda')
pipeline.to(device)

if __name__ == "__main__":
    # Label maps / per-class sample counts used for earlier datasets (for reference):
    # {'NRG':0, 'RG':1} 3270, {'MSIH':0, 'nonMSIH':1} :9979 {'No_Cardiomegaly':0, 'Cardiomegaly':1} 7869
    n_samples = 7869     # images written per class label
    sample_batch = 200   # images requested from pipeline.sample per call
    cfg = 1              # forwarded to pipeline.sample as guidance_scale
    labels = {'No_Cardiomegaly': 0, 'Cardiomegaly': 1}

    for steps in [50, 100, 150, 200, 250]:
        for name, label in labels.items():
            # Output directory: one sub-folder per class name, one root per step count.
            # path_out = Path(f'/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_{steps}/')/name
            path_out = Path(f'/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_{steps}')/name
            # path_out = Path('/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion')/name
            path_out.mkdir(parents=True, exist_ok=True)

            # --------- Generate Samples  -------------------
            # Re-seed for every (steps, label) pair so each pair starts from the same RNG state.
            torch.manual_seed(0)
            counter = 0
            for chunk in chunks(list(range(n_samples)), sample_batch):
                batch_size = len(chunk)
                if label is not None:
                    condition = torch.tensor([label] * batch_size, device=device)
                    # Guidance "unconditional" input: the other class of a 0/1 label
                    # (could also be None or a specific label).
                    un_cond = torch.tensor([1 - label] * batch_size, device=device)
                else:
                    condition = un_cond = None
                results = pipeline.sample(
                    batch_size, (8, 32, 32),
                    guidance_scale=cfg, condition=condition, un_cond=un_cond, steps=steps,
                )
                # results = pipeline.sample(len(chunk), (4, 64, 64), guidance_scale=cfg, condition=condition, un_cond=un_cond, steps=steps )

                # --------- Save result ----------------
                # Map [-1, 1] -> [0, 255] for the whole batch at once. Round to the
                # nearest integer: a bare astype(np.uint8) truncates toward zero and
                # darkens every pixel by ~0.5 grey level on average.
                images = results.cpu().numpy().clip(-1, 1)
                images = np.rint((images + 1) / 2 * 255).astype(np.uint8)
                # Move axis 1 (presumably channels) last so each image is (H, W, C) for PIL.
                images = np.moveaxis(images, 1, -1)
                for image in images:
                    if image.shape[-1] == 1:
                        image = np.squeeze(image, axis=-1)  # single channel -> 2-D grayscale
                    Image.fromarray(image).convert("RGB").save(path_out/f'fake_{counter}.png')
                    counter += 1

            # Return cached, unused GPU memory before the next label. The 3 s pause
            # is kept from the original; its purpose is not evident from this script.
            torch.cuda.empty_cache()
            time.sleep(3)