Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,513 Bytes
b0f3145 187715d b0f3145 187715d b0f3145 e9034bb 88dc089 187715d b0f3145 187715d b0f3145 f151301 e08daf1 187715d 59f3984 187715d 59f3984 187715d 0b120d5 b0f3145 187715d 0ce0e61 b0f3145 187715d 75859e2 187715d f151301 2a8c2d1 bef93f3 187715d f151301 187715d f151301 187715d bb36173 187715d df63923 187715d 2a8c2d1 e9028a2 187715d d099fe7 187715d d099fe7 187715d b0f3145 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
from PIL import Image
# Toggle for the NSFW safety checker that is loaded further below.
SAFETY_CHECKER = True

# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"  # Hub repo of the SDXL base pipeline
repo = "ByteDance/SDXL-Lightning"  # Hub repo hosting the Lightning UNet checkpoints

# Dropdown label -> [UNet checkpoint filename inside `repo`, number of inference steps].
# The 1-step file is the "_x0" variant, which generate_image pairs with
# prediction_type="sample" instead of "epsilon".
checkpoints = {
    label: [filename, steps]
    for label, filename, steps in (
        ("1-Step", "sdxl_lightning_1step_unet_x0.safetensors", 1),
        ("2-Step", "sdxl_lightning_2step_unet.safetensors", 2),
        ("4-Step", "sdxl_lightning_4step_unet.safetensors", 4),
        ("8-Step", "sdxl_lightning_8step_unet.safetensors", 8),
    )
}

# Step count whose checkpoint is currently loaded into pipe.unet; None until the first request.
loaded = None
# Build the SDXL base pipeline once, at import time, in fp16 on the GPU.
# NOTE(review): this runs at module level (not inside the @spaces.GPU-decorated
# generate_image), and `pipe` is only bound when CUDA is available; without
# CUDA, generate_image fails with a NameError on its first use of `pipe`.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
# Optional NSFW filter: a StableDiffusionSafetyChecker from the local
# safety_checker module, plus the CLIP image processor that builds its input.
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    # CLIPImageProcessor is the supported class; CLIPFeatureExtractor is only a
    # deprecated alias of it in transformers and emits a FutureWarning.
    from transformers import CLIPImageProcessor

    # NOTE(review): moved to CUDA unconditionally, unlike the pipeline, which
    # is guarded by torch.cuda.is_available().
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
def check_nsfw_images(
    images: list[Image.Image],
) -> tuple[list[Image.Image], list[bool]]:
    """Run the NSFW safety checker over generated images.

    Args:
        images: PIL images produced by the pipeline.

    Returns:
        ``(images, flags)``: the same ``images`` list that was passed in, and
        the checker's flags converted to plain Python bools (presumably one
        per image; confirm against safety_checker.py).
    """
    # BatchFeature.to moves every tensor it holds (including pixel_values) to
    # the GPU, so pixel_values needs no second .to("cuda").
    checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
    # Inference only: no gradients are needed for the safety check.
    with torch.no_grad():
        # NOTE(review): `images` is wrapped in an extra list here; this only
        # works if the local safety_checker ignores its `images` argument.
        # Confirm against safety_checker.py before changing it.
        has_nsfw_concepts = safety_checker(
            images=[images],
            clip_input=checker_input.pixel_values,
        )
    # Normalise whatever the checker returns (e.g. a bool tensor) to the
    # declared list[bool].
    return images, [bool(flag) for flag in has_nsfw_concepts]
# Function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` with the SDXL-Lightning checkpoint ``ckpt``.

    Args:
        prompt: Text prompt passed to the pipeline.
        ckpt: A key of ``checkpoints`` (e.g. "4-Step") selecting the UNet
            weights file and the number of inference steps.

    Returns:
        The first generated PIL image. When SAFETY_CHECKER is enabled and the
        image is flagged, a black image of the same size is returned instead.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    checkpoint, num_inference_steps = checkpoints[ckpt]

    # Swap the scheduler and UNet weights only when a different step count is
    # requested; `loaded` records which checkpoint is currently in pipe.unet.
    if loaded != num_inference_steps:
        # "trailing" timestep spacing; the 1-step "_x0" checkpoint uses
        # prediction_type="sample", the others "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if num_inference_steps == 1 else "epsilon",
        )
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, checkpoint), device="cuda")
        )
        loaded = num_inference_steps

    # Classifier-free guidance is disabled (guidance_scale=0).
    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
    image = results.images[0]

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        image = images[0]
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Never hand the flagged image back to the user; return a black
            # placeholder of the same size instead.
            return Image.new("RGB", image.size)
    return image
# Gradio Interface
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
    gr.HTML("<p><center>Lightning-fast text-to-image generation</center></p><p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Choices are taken from `checkpoints` so the dropdown always
            # matches the labels generate_image can look up.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the textbox and clicking the button run the same handler.
    prompt.submit(fn=generate_image,
                  inputs=[prompt, ckpt],
                  outputs=img,
                  )
    submit.click(fn=generate_image,
                 inputs=[prompt, ckpt],
                 outputs=img,
                 )

demo.queue().launch()