import gradio as gr
import PIL.Image
from pathlib import Path
import pandas as pd
from diffusers.pipelines import StableDiffusionPipeline
import torch
import argparse
import os
import warnings
from safetensors.torch import load_file
# load_unet_for_svdiff is assumed to be provided by the svdiff-pytorch package
from svdiff_pytorch import load_unet_for_svdiff
import yaml

warnings.filterwarnings("ignore")
OUTPUT_DIR = "OUTPUT"
cuda_device = 1
device = f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu"
print("DEVICE: ", device)

TITLE = "Demo for Generating Chest X-rays using Different Parameter-Efficient Fine-Tuned Stable Diffusion Pipelines"
INFO_ABOUT_TEXT_PROMPT = "The text prompt (radiology report impression) used to condition the generated chest X-ray."
INFO_ABOUT_GUIDANCE_SCALE = "Higher guidance scale values make the generated image follow the text prompt more closely."
INFO_ABOUT_INFERENCE_STEPS = "More denoising steps generally improve image quality at the cost of slower generation."
EXAMPLE_TEXT_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "Normal chest radiograph.",
    "No acute intrathoracic process.",
    "Mild pulmonary edema.",
    "No focal consolidation concerning for pneumonia",
    "No radiographic evidence for acute cardiopulmonary process",
]

def load_adapted_unet(unet_pretraining_type, pipe):
    """
    Loads the adapted U-Net for the selected PEFT type.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        pipe (StableDiffusionPipeline): The Stable Diffusion pipeline to use for generating the X-ray

    Returns:
        None
    """
    sd_folder_path = "runwayml/stable-diffusion-v1-5"
    exp_path = ""

    if unet_pretraining_type == "freeze":
        pass
    elif unet_pretraining_type == "svdiff":
        print("SV-DIFF UNET")
        pipe.unet = load_unet_for_svdiff(
            sd_folder_path,
            spectral_shifts_ckpt=os.path.join(
                os.path.join(exp_path, "unet"), "spectral_shifts.safetensors"
            ),
            subfolder="unet",
        )
        for module in pipe.unet.modules():
            if hasattr(module, "perform_svd"):
                module.perform_svd()
    elif unet_pretraining_type == "lorav2":
        exp_path = os.path.join(exp_path, "pytorch_lora_weights.safetensors")
        pipe.unet.load_attn_procs(exp_path)
    else:
        # Load the fine-tuned weights for the selected PEFT type into the U-Net
        state_dict = load_file(unet_pretraining_type + "_" + "diffusion_pytorch_model.safetensors")
        print(pipe.unet.load_state_dict(state_dict, strict=False))

def loadSDModel(unet_pretraining_type, cuda_device):
    """
    Loads the Stable Diffusion model for the selected PEFT type.

    Parameters:
        unet_pretraining_type (str): The type of PEFT to use for generating the X-ray
        cuda_device (str): The CUDA device to use for generating the X-ray

    Returns:
        pipe (StableDiffusionPipeline): The Stable Diffusion pipeline to use for generating the X-ray
    """
    sd_folder_path = "runwayml/stable-diffusion-v1-5"

    pipe = StableDiffusionPipeline.from_pretrained(sd_folder_path, revision="fp16")
    load_adapted_unet(unet_pretraining_type, pipe)

    # Disable the safety checker for chest X-ray generation
    pipe.safety_checker = None

    return pipe

def _predict_using_default_params():
    # Defining the default parameters
    unet_pretraining_type = "full"
    input_text = "No acute cardiopulmonary abnormality."
    guidance_scale = 4
    num_inference_steps = 75
    device = "0"
    OUTPUT_DIR = "OUTPUT"

    BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)
    NUM_TUNABLE_PARAMS = {
        "full": 86,
        "attention": 26.7,
        "bias": 0.343,
        "norm": 0.2,
        "norm_bias_attention": 26.7,
        "lorav2": 0.8,
        "svdiff": 0.222,
        "difffit": 0.581,
    }

    cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

    print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
    sd_pipeline = loadSDModel(
        unet_pretraining_type=unet_pretraining_type,
        cuda_device=cuda_device,
    )
    sd_pipeline.to(cuda_device)

    result_image = sd_pipeline(
        prompt=input_text,
        height=224,
        width=224,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    result_pil_image = result_image["images"][0]

    # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
    df = pd.DataFrame(
        {
            "Fine-Tuning Strategy": list(NUM_TUNABLE_PARAMS.keys()),
            "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
        }
    )
    print(df)
    df = df[df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])].reset_index(
        drop=True
    )

    bar_plot = gr.BarPlot(
        value=df,
        x="Fine-Tuning Strategy",
        y="Number of Tunable Parameters",
        title=BARPLOT_TITLE,
        vertical=True,
        height=300,
        width=300,
        interactive=True,
    )

    return result_pil_image, bar_plot

def predict(
    unet_pretraining_type,
    input_text,
    guidance_scale=4,
    num_inference_steps=75,
    device="0",
    OUTPUT_DIR="OUTPUT",
):
    try:
        BARPLOT_TITLE = "Tunable Parameters for {} Fine-Tuning".format(unet_pretraining_type)

        # Tunable parameter counts for each fine-tuning strategy (used for the bar plot)
        NUM_TUNABLE_PARAMS = {
            "full": 86,
            "attention": 26.7,
            "bias": 0.343,
            "norm": 0.2,
            "norm_bias_attention": 26.7,
            "lorav2": 0.8,
            "svdiff": 0.222,
            "difffit": 0.581,
        }

        cuda_device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"

        print("Loading Pipeline for {} Fine-Tuning".format(unet_pretraining_type))
        sd_pipeline = loadSDModel(
            unet_pretraining_type=unet_pretraining_type,
            cuda_device=cuda_device,
        )
        sd_pipeline.to(cuda_device)

        result_image = sd_pipeline(
            prompt=input_text,
            height=224,
            width=224,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
        result_pil_image = result_image["images"][0]

        # Create a Bar Plot displaying the number of tunable parameters for the selected PEFT Type
        df = pd.DataFrame(
            {
                "Fine-Tuning Strategy": list(NUM_TUNABLE_PARAMS.keys()),
                "Number of Tunable Parameters": list(NUM_TUNABLE_PARAMS.values()),
            }
        )
        print(df)
        df = df[df["Fine-Tuning Strategy"].isin(["full", unet_pretraining_type])].reset_index(
            drop=True
        )

        bar_plot = gr.BarPlot(
            value=df,
            x="Fine-Tuning Strategy",
            y="Number of Tunable Parameters",
            title=BARPLOT_TITLE,
            vertical=True,
            height=300,
            width=300,
            interactive=True,
        )

        return result_pil_image, bar_plot
    except Exception:
        # Fall back to generating with the default parameters if anything goes wrong
        return _predict_using_default_params()

# Create a Gradio interface
"""
Input Parameters:
    1. PEFT Type: (Dropdown) The type of PEFT to use for generating the X-ray
    2. Input Text: (Dropdown) The text prompt to use for generating the X-ray
    3. Guidance Scale: (Slider) The guidance scale to use for generating the X-ray
    4. Num Inference Steps: (Slider) The number of inference steps to use for generating the X-ray

Output Parameters:
    1. Generated X-ray Image: (Image) The generated X-ray image
    2. Number of Tunable Parameters: (Bar Plot) The number of tunable parameters for the selected PEFT Type
"""
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            ["full", "difffit", "norm", "bias", "attention", "norm_bias_attention"],
            value="full",
            label="PEFT Type",
        ),
        gr.Dropdown(
            EXAMPLE_TEXT_PROMPTS,
            info=INFO_ABOUT_TEXT_PROMPT,
            label="Input Text",
            value=EXAMPLE_TEXT_PROMPTS[0],
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=4,
            step=1,
            info=INFO_ABOUT_GUIDANCE_SCALE,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=75,
            step=1,
            info=INFO_ABOUT_INFERENCE_STEPS,
            label="Num Inference Steps",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.BarPlot()],
    live=True,
    analytics_enabled=False,
    title=TITLE,
)
# Launch the Gradio interface
iface.launch(share=True)
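# Minimal sketch of calling predict() programmatically, outside the Gradio UI
# (parameter values are illustrative and the output file name is hypothetical):
#
#     image, plot = predict(
#         unet_pretraining_type="difffit",
#         input_text="No acute cardiopulmonary abnormality.",
#         guidance_scale=4,
#         num_inference_steps=75,
#     )
#     image.save("generated_xray.png")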