Last commit not found
import spaces | |
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel | |
from tdd_scheduler import TDDScheduler | |
from safetensors.torch import load_file | |
from PIL import Image | |
import transformers | |
# Run transformers' cache-migration helper once at import time.
# NOTE(review): presumably a one-off move of a legacy cache layout; confirm it
# is still required for the pinned transformers version.
transformers.utils.move_cache()

# When True, every generated image is screened by the safety checker.
SAFETY_CHECKER = True

# Name of the LoRA adapter most recently reported as active by generate_image.
loaded_acc = None

# Torch device that all models are moved to.
device = "cuda"

# Adapter name -> LoRA weight file name inside the "RED-AIGC/TDD" repository.
ACC_lora = dict(
    [
        ("TDD", "sdxl_tdd_wo_adv_lora.safetensors"),
        ("TDD_adv", "sdxl_tdd_lora_weights.safetensors"),
    ]
)
def _build_tdd_pipeline(unet):
    """Build an SDXL pipeline around *unet* with the TDD LoRAs and scheduler.

    The pipeline skeleton (text encoders, VAE, ...) always comes from
    "stabilityai/stable-diffusion-xl-base-1.0"; only the UNet differs between
    the pipelines this script creates.

    Args:
        unet: An already loaded ``UNet2DConditionModel``.

    Returns:
        A ``StableDiffusionXLPipeline`` on ``device`` with every adapter listed
        in ``ACC_lora`` loaded and a ``TDDScheduler`` installed.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)
    # Load both acceleration LoRAs up front so a request only has to switch
    # the active adapter instead of loading weights.
    for adapter_name, weight_name in ACC_lora.items():
        pipe.load_lora_weights(
            "RED-AIGC/TDD", weight_name=weight_name, adapter_name=adapter_name
        )
    # Derive the TDD scheduler from this pipeline's own scheduler config
    # (the original built the second pipeline's scheduler from the first one).
    pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config)
    return pipe


if torch.cuda.is_available():
    # UNet of the stock SDXL base model.
    base1 = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        torch_dtype=torch.float16,
    ).to(device)
    # UNet of the RealVisXL checkpoint, selected in the UI as "Real".
    base2 = UNet2DConditionModel.from_pretrained(
        "frankjoshua/realvisxlV40_v40Bakedvae",
        subfolder="unet",
        torch_dtype=torch.float16,
    ).to(device)

    pipe_sdxl = _build_tdd_pipeline(base1)
    pipe_sdxl_real = _build_tdd_pipeline(base2)
def update_base_model(ckpt):
    """Load a fresh stock SDXL base pipeline onto ``device``.

    Args:
        ckpt: Requested base-model name. It is currently ignored: the stock
            "stabilityai/stable-diffusion-xl-base-1.0" pipeline is always
            loaded.

    Returns:
        The loaded ``StableDiffusionXLPipeline``, or ``None`` when CUDA is not
        available (nothing is loaded in that case).
    """
    # Guard clause: make the no-GPU path explicit instead of falling through
    # to a ``return`` of a local that was never assigned.
    if not torch.cuda.is_available():
        return None

    return StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    # Classifier that flags NSFW concepts, moved to the inference device.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    # Preprocessor that turns PIL images into CLIP pixel tensors.
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over *images*.

        Args:
            images: PIL images to screen.

        Returns:
            A pair ``(images, flags)``: the input list, passed through as is,
            and whatever the safety checker returned for it.
            NOTE(review): the annotation promises one bool per image; confirm
            against the local ``safety_checker`` module.
        """
        clip_batch = feature_extractor(images, return_tensors="pt").to(device)
        pixel_values = clip_batch.pixel_values.to(device)
        nsfw_flags = safety_checker(images=[images], clip_input=pixel_values)
        return images, nsfw_flags
def generate_image(
    prompt,
    negative_prompt,
    ckpt,
    acc,
    num_inference_steps,
    guidance_scale,
    eta,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected base model and TDD LoRA.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        ckpt: Base model choice; "Real" selects ``pipe_sdxl_real``, any other
            value selects ``pipe_sdxl``.
        acc: Name of the LoRA adapter to activate (a key of ``ACC_lora``).
        num_inference_steps: Number of sampling steps; coerced to ``int``.
        guidance_scale: CFG scale forwarded to the pipeline.
        eta: ``eta`` value forwarded to the pipeline.
        seed: Seed for the torch generator; coerced to ``int`` because UI
            number fields may deliver floats.
        progress: Gradio progress tracker. It is not used in the body;
            presumably declared so Gradio mirrors tqdm progress in the UI.

    Returns:
        The first generated PIL image, or a blank 512x512 RGB image when the
        safety checker is enabled and flags the result.
    """
    global loaded_acc

    pipe = pipe_sdxl_real if ckpt == "Real" else pipe_sdxl

    # Each pipeline keeps its own active-adapter state, so the adapter must be
    # set on the pipeline actually used. Gating this on the single global
    # `loaded_acc` skipped the call after a base-model switch and left the
    # newly selected pipeline on whatever adapter it had before.
    pipe.set_adapters([acc], adapter_weights=[1.0])
    if loaded_acc != acc:
        print(pipe.get_active_adapters())
        loaded_acc = acc

    generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
    results = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        eta=eta,
        generator=generator,
    )

    if not SAFETY_CHECKER:
        return results.images[0]

    images, has_nsfw_concepts = check_nsfw_images(results.images)
    if any(has_nsfw_concepts):
        gr.Warning("NSFW content detected.")
        # Hand back a blank placeholder instead of the flagged image.
        return Image.new("RGB", (512, 512))
    return images[0]
# Custom stylesheet for the Gradio page: centre the title and cap the page
# width. The stray "| |" fragments that had leaked into this literal made the
# declarations malformed, so they are removed here.
css = """
h1 {
    text-align: center;
    display:block;
}
.gradio-container {
    max-width: 70.5rem !important;
}
"""
# Build the Gradio UI: prompt and sampling controls on the left, the result
# image and run button on the right, cached examples below.
# NOTE(review): container nesting was reconstructed from statement order;
# confirm the layout against the running app.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# ✨Target-Driven Distillation✨
Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation model that largely accelerates the inference processes of diffusion models. Using its delicate strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images with only a few steps.
[arXiv](https://arxiv.org) [Hugging Face](https://huggingface.co/RedAIGC/TDD)
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    negative_prompt = gr.Textbox(label="Negative Prompt")
                with gr.Row():
                    steps = gr.Slider(
                        label="Sampling Steps",
                        minimum=4,
                        maximum=8,
                        step=1,
                        value=4,
                    )
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=1,
                        maximum=4,
                        step=0.1,
                        value=2.0,
                    )
                with gr.Row():
                    eta = gr.Slider(
                        label="eta",
                        minimum=0,
                        maximum=0.3,
                        step=0.01,
                        value=0.2,
                    )
                with gr.Row():
                    # precision=0 makes the field deliver an integer, which is
                    # what torch's generator seeding requires.
                    seed = gr.Number(label="Seed", value=-1, precision=0)
                with gr.Row():
                    ckpt = gr.Dropdown(
                        label="Base Model",
                        choices=["SDXL-1.0", "Real"],
                        value="SDXL-1.0",
                    )
                    acc = gr.Dropdown(
                        label="Accelerate Lora",
                        choices=["TDD", "TDD_adv"],
                        value="TDD_adv",
                    )
        with gr.Column(scale=1):
            with gr.Group():
                img = gr.Image(label="TDD Image", value="cat.png")
            submit_sdxl = gr.Button("Run on SDXL")

    # Each row follows the order of `inputs` below. The comma after the second
    # row was missing, which turned the literal into `list[tuple]` indexing
    # and raised TypeError at import time.
    gr.Examples(
        examples=[
            ["A photo of a cat made of water.", "", "SDXL-1.0", "TDD_adv", 4, 1.7, 0.2, 546237],
            ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", "NSFW, low quality, low resolution, normal quality", "Real", "TDD_adv", 7, 2.0, 0.3, 68436],
            ["panda mad scientist mixing sparkling chemicals", "", "Real", "TDD_adv", 5, 1.7, 0.2, 3853],
            ["a dog in snow, 8k", "", "SDXL-1.0", "TDD", 6, 1.3, 0.2, 12138],
        ],
        inputs=[prompt, negative_prompt, ckpt, acc, steps, guidance_scale, eta, seed],
        outputs=[img],
        fn=generate_image,
        cache_examples="lazy",
    )

    # Regenerate when the base model changes, the prompt is submitted, or the
    # run button is clicked.
    gr.on(
        fn=generate_image,
        triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
        inputs=[prompt, negative_prompt, ckpt, acc, steps, guidance_scale, eta, seed],
        outputs=[img],
    )

demo.queue(api_open=False).launch(show_api=False)