import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
device = "cuda"
dtype = torch.float16
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
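
# Map each dropdown choice to its checkpoint file on the Hub and its inference step count.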
opts = {
    "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
}

# Inference function.
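# spaces.GPU() requests a GPU for the duration of each call when the Space
# runs on ZeroGPU hardware (it is a no-op elsewhere).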
@spaces.GPU()
def generate(prompt, option, progress=gr.Progress()):
    print(prompt, option)
    ckpt, step = opts[option]
    progress(0, desc="Initializing the model")

    # Main pipeline.
    unet = UNet2DConditionModel.from_config(base, subfolder="unet")
    pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
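    # SDXL-Lightning requires "trailing" timestep spacing; the 1-step checkpoint
    # predicts the denoised sample (x0) directly, while the multi-step
    # checkpoints use standard epsilon prediction.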
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")

    # Safety checker.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image_processor = VaeImageProcessor(vae_scale_factor=8)
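
    # Per-step callback; it only drives the progress bar and must return the
    # callback kwargs unchanged.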
    def inference_callback(p, i, t, kwargs):
        progress((i + 1, step))
        return kwargs

    # Inference loop.
    progress((0, step))
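    # guidance_scale=0 disables classifier-free guidance, which the Lightning
    # checkpoints are distilled to run without.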
    results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")

    # Safety check: render the output tensors to PIL for the CLIP feature
    # extractor, then screen the batch for NSFW content.
    feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
    safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
    pixel_values = safety_checker_input.pixel_values.to(device, dtype)
    images, has_nsfw_concept = safety_checker(
        images=results.images, clip_input=pixel_values
    )
    if has_nsfw_concept[0]:
        print(f"Safety checker triggered on prompt: {prompt}")
    # gr.Image expects a PIL image or NumPy array rather than a torch tensor,
    # so convert the (possibly blacked-out) batch before returning.
    return VaeImageProcessor.numpy_to_pil(VaeImageProcessor.pt_to_numpy(images))[0]


with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>" +
        "<p><center>Lightning-fast text-to-image generation</center></p>" +
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8
        )
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True
        )
        submit = gr.Button(
            scale=1,
            variant="primary"
        )
    img = gr.Image(label="SDXL-Lightning Generated Image")
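
    # Both submitting the textbox (Enter) and clicking the button run generation.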
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
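
    # Example prompts; caching is disabled, so each example is generated live.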
    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Rabbit portrait in a forest, fantasy", "4 Steps"],
            ["A panda swimming", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=False,
    )

    gr.HTML(
        "<p><small><center>This demo was built by the community</center></small></p>"
    )

demo.queue().launch()