Spaces:
Runtime error
Runtime error
import gradio as gr | |
import PIL.Image | |
from pathlib import Path | |
import pandas as pd | |
from diffusers.pipelines import StableDiffusionPipeline | |
import torch | |
import argparse | |
import os | |
import warnings | |
from safetensors.torch import load_file | |
import yaml | |
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# NOTE(review): not read by any code visible in this file; `predict` carries
# its own OUTPUT_DIR default.
OUTPUT_DIR = "OUTPUT"

# Module-level device string: the GPU at index `cuda_device` when CUDA is
# available, otherwise the CPU.
# NOTE(review): the functions in this file build their own device strings from
# index "0" instead of reading this value - confirm which index is intended.
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"

# Page title handed to the Gradio interface (spelling of "Different" fixed).
TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"

# Placeholder help texts passed as `info=` to the input widgets.
INFO_ABOUT_TEXT_PROMPT = "INFO_ABOUT_TEXT_PROMPT"
INFO_ABOUT_GUIDANCE_SCALE = "INFO_ABOUT_GUIDANCE_SCALE"
INFO_ABOUT_INFERENCE_STEPS = "INFO_ABOUT_INFERENCE_STEPS"

# Prompts offered in the "Input Text" dropdown.
EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]
def load_adapted_unet(unet_pretraining_type, exp_path, pipe):
    """
    Load the fine-tuned U-Net weights for the selected PEFT type into ``pipe``.

    The pipeline is modified in place.

    Parameters:
        unet_pretraining_type (str): The PEFT variant. "freeze" leaves the
            pipeline untouched, "svdiff" and "lorav2" use dedicated loaders,
            and any other value is loaded as a safetensors state dict.
        exp_path (str): Location of the trained weights. For "svdiff" and
            "lorav2" it is joined with further path components, so it is used
            as a directory; for every other type it is the safetensors file.
        pipe (StableDiffusionPipeline): The pipeline whose U-Net is adapted.
    Returns:
        None
    Raises:
        ImportError: If "svdiff" is requested but ``load_unet_for_svdiff`` is
            not available in this module.
    """
    sd_folder_path = "runwayml/stable-diffusion-v1-5"

    if unet_pretraining_type == "freeze":
        # Base weights are used as-is.
        return

    if unet_pretraining_type == "svdiff":
        # The loader is not imported anywhere in the visible file; fail with
        # an actionable message instead of a bare NameError.
        try:
            svdiff_loader = load_unet_for_svdiff
        except NameError as err:
            raise ImportError(
                "PEFT type 'svdiff' needs `load_unet_for_svdiff`, which is not "
                "imported in this module."
            ) from err

        print("SV-DIFF UNET")
        pipe.unet = svdiff_loader(
            sd_folder_path,
            spectral_shifts_ckpt=os.path.join(
                exp_path, "unet", "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        # Let every module that exposes `perform_svd` run it once after the
        # replacement U-Net is in place.
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
        return

    if unet_pretraining_type == "lorav2":
        lora_weights = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(lora_weights)
        return

    # Remaining types ship a (partial) U-Net state dict. Honour the caller's
    # path; fall back to the conventional file name only when none was given.
    weights_path = (
        exp_path or f"{unet_pretraining_type}_diffusion_pytorch_model.safetensors"
    )
    state_dict = load_file(weights_path)
    # strict=False: the checkpoint may hold only a subset of the U-Net's keys.
    # The printed result lists the missing and unexpected keys.
    print(pipe.unet.load_state_dict(state_dict, strict=False))
def loadSDModel(unet_pretraining_type, exp_path, cuda_device):
    """
    Build a Stable Diffusion pipeline and adapt its U-Net for one PEFT type.

    Parameters:
        unet_pretraining_type (str): The PEFT variant whose weights are loaded
        exp_path (str): Path to the trained weights for that PEFT variant
        cuda_device (str): Target device string.
            NOTE(review): accepted but never read here - the returned pipeline
            is not moved to any device by this function.
    Returns:
        pipe (StableDiffusionPipeline): The adapted pipeline, with its safety
        checker removed
    """
    base_model_id = "runwayml/stable-diffusion-v1-5"
    pipeline = StableDiffusionPipeline.from_pretrained(base_model_id, revision="fp16")

    # Replace or patch the U-Net weights in place for the requested PEFT type.
    load_adapted_unet(unet_pretraining_type, exp_path, pipeline)

    pipeline.safety_checker = None
    return pipeline
def load_all_pipelines():
    """
    Load one Stable Diffusion pipeline per PEFT type so they can be cached
    (Design Choice 2).

    Parameters:
        None
    Returns:
        tuple: Six pipelines in this fixed order - full, norm, bias,
        attention, norm_bias_attention (NBA), difffit.
    """
    # Checkpoint file holding the best trained model for each PEFT type.
    MODEL_PATH_DICT = {
        "full": "full_diffusion_pytorch_model.safetensors",
        "norm": "norm_diffusion_pytorch_model.safetensors",
        "bias": "bias_diffusion_pytorch_model.safetensors",
        "attention": "attention_diffusion_pytorch_model.safetensors",
        "norm_bias_attention": "norm_bias_attention_diffusion_pytorch_model.safetensors",
        "difffit": "difffit_diffusion_pytorch_model.safetensors",
    }

    # (PEFT type, name used in the progress message), in return order.
    load_order = (
        ("full", "Full"),
        ("norm", "Norm"),
        ("bias", "Bias"),
        ("attention", "Attention"),
        ("norm_bias_attention", "NBA"),
        ("difffit", "Difffit"),
    )

    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    pipelines = []
    for peft_type, label in load_order:
        print(f"Loading Pipeline for {label} Fine-Tuning")
        pipelines.append(
            loadSDModel(
                unet_pretraining_type=peft_type,
                exp_path=MODEL_PATH_DICT[peft_type],
                cuda_device=target_device,
            )
        )
    return tuple(pipelines)
# LOAD ALL PIPELINES FIRST AND CACHE THEM | |
# ( | |
# sd_pipeline_full, | |
# sd_pipeline_norm, | |
# sd_pipeline_bias, | |
# sd_pipeline_attention, | |
# sd_pipeline_NBA, | |
# sd_pipeline_difffit, | |
# ) = load_all_pipelines() | |
# PIPELINE_DICT = { | |
# "full": sd_pipeline_full, | |
# "norm": sd_pipeline_norm, | |
# "bias": sd_pipeline_bias, | |
# "attention": sd_pipeline_attention, | |
# "norm_bias_attention": sd_pipeline_NBA, | |
# "difffit": sd_pipeline_difffit, | |
# } | |
def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
    PIPELINE_DICT=None,
):
    """
    Generate an image for a text prompt with the pipeline of one PEFT type.

    Parameters:
        unet_pretraining_type (str): The PEFT variant to generate with
        input_text (str): The text prompt
        guidance_scale (float): Guidance scale forwarded to the pipeline
        num_inference_steps (int): Number of denoising steps forwarded to the
            pipeline
        device (str): CUDA device index; ignored when CUDA is unavailable
        OUTPUT_DIR (str): Unused; kept so existing callers keep working
        PIPELINE_DICT (dict | None): Optional cache mapping PEFT type to an
            already loaded pipeline. When it lacks the requested type (or is
            None) the pipeline is loaded on demand.
    Returns:
        tuple: (first image returned by the pipeline, gr.BarPlot comparing the
        tunable-parameter figure of the selected PEFT type with "full")
    Raises:
        ValueError: If no PEFT type or no prompt was supplied.
    """
    # The interface runs in live mode, so this function is invoked while the
    # form is only partially filled in. Reject that before loading a model.
    if not unet_pretraining_type:
        raise ValueError("Select a PEFT type before generating an image.")
    if not input_text:
        raise ValueError("Provide a text prompt before generating an image.")

    NUM_TUNABLE_PARAMS = {
        "full": 86,
        "attention": 26.7,
        "bias": 0.343,
        "norm": 0.2,
        "norm_bias_attention": 26.7,
        "lorav2": 0.8,
        "svdiff": 0.222,
        "difffit": 0.581,
    }

    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    # Prefer a cached pipeline when the caller supplied one.
    sd_pipeline = None
    if PIPELINE_DICT is not None:
        sd_pipeline = PIPELINE_DICT.get(unet_pretraining_type)

    if sd_pipeline is None:
        print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
        # Checkpoints follow "<type>_diffusion_pytorch_model.safetensors".
        # NOTE(review): "svdiff" and "lorav2" expect an experiment directory
        # rather than this file name - confirm the correct location for them.
        sd_pipeline = loadSDModel(
            unet_pretraining_type=unet_pretraining_type,
            exp_path=f"{unet_pretraining_type}_diffusion_pytorch_model.safetensors",
            cuda_device=cuda_device,
        )

    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]

    # Bar plot: tunable-parameter figure of the selected PEFT type next to
    # the "full" fine-tuning baseline.
    df = pd.DataFrame(
        {
            "PEFT Type": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )
    df = df[df["PEFT Type"].isin(["full", unet_pretraining_type])].reset_index(
        drop=True
    )

    bar_plot = gr.BarPlot(
        value=df,
        x="PEFT Type",
        y="Number of Tunable Parameters",
        label="PEFT Type",
        title="Number of Tunable Parameters",
        vertical=False,
    )
    return result_pil_image, bar_plot
# Create a Gradio interface | |
""" | |
Input Parameters: | |
1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray | |
2. Input Text: (Textbox) The text prompt to use for generating the X-ray | |
3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray | |
4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray | |
Output Parameters: | |
1. Generated X-ray Image: (Image) The generated X-ray image | |
2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT Type | |
""" | |
# Input widgets, in the positional order `predict` receives their values.
_peft_type_input = gr.Dropdown(
    ["full", "difffit", "svdiff", "norm", "bias", "attention"],
    label="PEFT Type",
)
_prompt_input = gr.Dropdown(
    EXAMPLE_TEXT_PROMPTS,
    info=INFO_ABOUT_TEXT_PROMPT,
    label="Input Text",
)
_guidance_scale_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=4,
    step=1,
    info=INFO_ABOUT_GUIDANCE_SCALE,
    label="Guidance Scale",
)
_inference_steps_input = gr.Slider(
    minimum=1,
    maximum=100,
    value=75,
    step=1,
    info=INFO_ABOUT_INFERENCE_STEPS,
    label="Num Inference Steps",
)

# Wire the widgets to `predict`; the outputs are the generated image and the
# tunable-parameter bar plot.
iface = gr.Interface(
    fn=predict,
    inputs=[
        _peft_type_input,
        _prompt_input,
        _guidance_scale_input,
        _inference_steps_input,
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)

# Start serving the interface.
iface.launch(share=True)