from pathlib import Path import torch from torchvision import utils import math from medical_diffusion.models.pipelines import DiffusionPipeline import numpy as np from PIL import Image import time def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i:i + n] # ------------ Load Model ------------ device = torch.device('cuda') # pipeline = DiffusionPipeline.load_best_checkpoint(path_run_dir) pipeline = DiffusionPipeline.load_from_checkpoint('runs/2022_12_12_171357_chest_diffusion/last.ckpt') pipeline.to(device) if __name__ == "__main__": # {'NRG':0, 'RG':1} 3270, {'MSIH':0, 'nonMSIH':1} :9979 {'No_Cardiomegaly':0, 'Cardiomegaly':1} 7869 for steps in [50, 100, 150, 200, 250]: for name, label in {'No_Cardiomegaly':0, 'Cardiomegaly':1}.items(): n_samples = 7869 sample_batch = 200 cfg = 1 # path_out = Path(f'/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_{steps}/')/name path_out = Path(f'/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_{steps}')/name # path_out = Path('/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion')/name path_out.mkdir(parents=True, exist_ok=True) # --------- Generate Samples ------------------- torch.manual_seed(0) counter = 0 for chunk in chunks(list(range(n_samples)), sample_batch): condition = torch.tensor([label]*len(chunk), device=device) if label is not None else None un_cond = torch.tensor([1-label]*len(chunk), device=device) if label is not None else None # Might be None, or 1-condition or specific label results = pipeline.sample(len(chunk), (8, 32, 32), guidance_scale=cfg, condition=condition, un_cond=un_cond, steps=steps) # results = pipeline.sample(len(chunk), (4, 64, 64), guidance_scale=cfg, condition=condition, un_cond=un_cond, steps=steps ) results = results.cpu().numpy() # --------- Save result ---------------- for image in results: image = image.clip(-1, 1) # or (image-image.min())/(image.max()-image.min()) image = (image+1)/2*255 # Transform from [-1, 1] to [0, 1] to [0, 255] image = np.moveaxis(image, 0, -1) image = image.astype(np.uint8) image = np.squeeze(image, axis=-1) if image.shape[-1]==1 else image Image.fromarray(image).convert("RGB").save(path_out/f'fake_{counter}.png') counter += 1 torch.cuda.empty_cache() time.sleep(3)