import argparse
import os
import warnings
from pathlib import Path

import gradio as gr
import pandas as pd
import PIL.Image
import torch
import yaml
from diffusers.pipelines import StableDiffusionPipeline
from safetensors.torch import load_file

warnings.filterwarnings("ignore")

OUTPUT_DIR = "OUTPUT"
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"

TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"

EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]

# Hub id of the base Stable Diffusion checkpoint every pipeline starts from.
SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Path to the fine-tuned U-Net weights for each PEFT type.
# Kept at module level because both load_all_pipelines() and predict() need it.
MODEL_PATH_DICT = {
    "full": "full_diffusion_pytorch_model.safetensors",
    "norm": "norm_diffusion_pytorch_model.safetensors",
    "bias": "bias_diffusion_pytorch_model.safetensors",
    "attention": "attention_diffusion_pytorch_model.safetensors",
    "norm_bias_attention": "norm_bias_attention_diffusion_pytorch_model.safetensors",
    "difffit": "difffit_diffusion_pytorch_model.safetensors",
}

# Value plotted in the bar chart for each PEFT type.
# NOTE(review): the unit is not stated anywhere in this file - confirm before labelling.
NUM_TUNABLE_PARAMS = {
    "full": 86,
    "attention": 26.7,
    "bias": 0.343,
    "norm": 0.2,
    "norm_bias_attention": 26.7,
    "lorav2": 0.8,
    "svdiff": 0.222,
    "difffit": 0.581,
}

# (PEFT type, label used in the progress message), in the order
# load_all_pipelines() returns the pipelines.
_ALL_PIPELINE_TYPES = (
    ("full", "Full"),
    ("norm", "Norm"),
    ("bias", "Bias"),
    ("attention", "Attention"),
    ("norm_bias_attention", "NBA"),
    ("difffit", "Difffit"),
)


def load_adapted_unet(unet_pretraining_type, exp_path, pipe):
    """
    Loads the adapted U-Net for the selected PEFT Type into `pipe` (in place)

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        exp_path (str): The path to the best trained model for the selected PEFT Type.
            For "svdiff" it is a directory containing "unet/spectral_shifts.safetensors",
            for "lorav2" a directory containing "pytorch_lora_weights.safetensors",
            otherwise the .safetensors file itself. If empty, the file name
            "<unet_pretraining_type>_diffusion_pytorch_model.safetensors" is used.
        pipe (StableDiffusionPipeline): The Stable Diffusion Pipeline to use for generating the X-ray

    Returns:
        None
    """
    if unet_pretraining_type == "freeze":
        # Frozen U-Net: keep the pretrained weights untouched.
        return

    if unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        # NOTE(review): `load_unet_for_svdiff` is not imported or defined anywhere
        # in this file, so this branch raises NameError as written. Confirm where
        # it comes from and add the import.
        pipe.unet = load_unet_for_svdiff(
            SD_MODEL_ID,
            spectral_shifts_ckpt=os.path.join(
                exp_path, "unet", "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
        return

    if unet_pretraining_type == "lorav2":
        lora_weights = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(lora_weights)
        return

    # Remaining types ship a (partial) U-Net state dict. Honor the path that was
    # passed in; fall back to the conventional file name only when it is missing.
    weights_path = exp_path or (
        unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors"
    )
    state_dict = load_file(weights_path)
    # strict=False: the checkpoint may hold only the fine-tuned subset of weights.
    # The printed result lists the missing / unexpected keys.
    print(pipe.unet.load_state_dict(state_dict, strict=False))


def loadSDModel(unet_pretraining_type, exp_path, cuda_device):
    """
    Loads the Stable Diffusion Model for the selected PEFT Type

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        exp_path (str): The path to the best trained model for the selected PEFT Type
        cuda_device (str): The CUDA device to use for generating the X-ray.
            Currently unused here: the caller moves the pipeline to the device.

    Returns:
        pipe (StableDiffusionPipeline): The Stable Diffusion Pipeline to use for generating the X-ray
    """
    pipe = StableDiffusionPipeline.from_pretrained(SD_MODEL_ID, revision="fp16")
    load_adapted_unet(unet_pretraining_type, exp_path, pipe)
    pipe.safety_checker = None
    return pipe


def load_all_pipelines():
    """
    Loads all the Stable Diffusion Pipelines for each PEFT Type for efficient caching (Design Choice 2)

    Parameters:
        None

    Returns:
        sd_pipeline_full (StableDiffusionPipeline): The Stable Diffusion Pipeline for Full Fine-Tuning
        sd_pipeline_norm (StableDiffusionPipeline): The Stable Diffusion Pipeline for Norm Fine-Tuning
        sd_pipeline_bias (StableDiffusionPipeline): The Stable Diffusion Pipeline for Bias Fine-Tuning
        sd_pipeline_attention (StableDiffusionPipeline): The Stable Diffusion Pipeline for Attention Fine-Tuning
        sd_pipeline_NBA (StableDiffusionPipeline): The Stable Diffusion Pipeline for NBA Fine-Tuning
        sd_pipeline_difffit (StableDiffusionPipeline): The Stable Diffusion Pipeline for Difffit Fine-Tuning
    """
    device = "0"
    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    pipelines = []
    for peft_type, label in _ALL_PIPELINE_TYPES:
        print(f"Loading Pipeline for {label} Fine-Tuning")
        pipelines.append(
            loadSDModel(
                unet_pretraining_type=peft_type,
                exp_path=MODEL_PATH_DICT[peft_type],
                cuda_device=cuda_device,
            )
        )
    return tuple(pipelines)


def _tunable_params_frame(unet_pretraining_type):
    """Builds the bar-plot DataFrame: the "full" row plus the selected PEFT type's row."""
    df = pd.DataFrame(
        {
            "PEFT Type": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )
    selected = df["PEFT Type"].isin(["full", unet_pretraining_type])
    return df[selected].reset_index(drop=True)


def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
    PIPELINE_DICT=None,
):
    """
    Generates an X-ray for the text prompt with the selected PEFT pipeline

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        input_text (str): The text prompt to use for generating the X-ray
        guidance_scale (float): The guidance scale passed to the pipeline
        num_inference_steps (int): The number of inference steps passed to the pipeline
        device (str): The CUDA device index; "cpu" is used when CUDA is unavailable
        OUTPUT_DIR (str): Unused; kept for interface compatibility
        PIPELINE_DICT (dict | None): Optional mapping of PEFT type to an already
            loaded pipeline (e.g. built from load_all_pipelines()). When the
            selected type is not in it, the pipeline is loaded on demand.

    Returns:
        result_pil_image: The first image of the pipeline output
        bar_plot (gr.BarPlot): Tunable-parameter values of "full" and the selected PEFT Type

    Raises:
        gr.Error: If no PEFT type / prompt is selected, or no checkpoint is
            configured for the selected PEFT type.
    """
    # A live interface can fire before every input has a value.
    if not unet_pretraining_type:
        raise gr.Error("Please select a PEFT Type.")
    if not input_text:
        raise gr.Error("Please select an Input Text.")

    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    sd_pipeline = None
    if PIPELINE_DICT:
        sd_pipeline = PIPELINE_DICT.get(unet_pretraining_type)

    if sd_pipeline is None:
        if unet_pretraining_type not in MODEL_PATH_DICT:
            raise gr.Error(
                f"No checkpoint is configured for PEFT Type '{unet_pretraining_type}'."
            )
        print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
        sd_pipeline = loadSDModel(
            unet_pretraining_type=unet_pretraining_type,
            exp_path=MODEL_PATH_DICT[unet_pretraining_type],
            cuda_device=cuda_device,
        )

    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]

    # Bar plot comparing the selected PEFT type against full fine-tuning.
    bar_plot = gr.BarPlot(
        value=_tunable_params_frame(unet_pretraining_type),
        x="PEFT Type",
        y="Number of Tunable Parameters",
        label="PEFT Type",
        title="Number of Tunable Parameters",
        vertical=False,
    )

    return result_pil_image, bar_plot


# Create a Gradio interface
#
# Input Parameters:
# 1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray
# 2. Input Text: (Dropdown) The text prompt to use for generating the X-ray
# 3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray
# 4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray
#
# Output Parameters:
# 1. Generated X-ray Image: (Image) The generated X-ray image
# 2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT Type
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            ["full", "difffit", "svdiff", "norm", "bias", "attention"],
            label="PEFT Type",
        ),
        gr.Dropdown(
            EXAMPLE_TEXT_PROMPTS, info=INFO_ABOUT_TEXT_PROMPT, label="Input Text"
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            info=INFO_ABOUT_GUIDANCE_SCALE,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=75,
            step=1,
            info=INFO_ABOUT_INFERENCE_STEPS,
            label="Num Inference Steps",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# Launch the Gradio interface
iface.launch(share=True)