gaur3009 committed
Commit 9a8c17e · verified · 1 Parent(s): 108897b

Create app.py

Files changed (1)
app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image
+
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
+
+ # Constants
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+ }
+ loaded = None
+
+
+ # Ensure model and scheduler are initialized in GPU-enabled function
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
+     )
+
+     def check_nsfw_images(
+         images: list[Image.Image],
+     ) -> tuple[list[Image.Image], list[bool]]:
+         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+         has_nsfw_concepts = safety_checker(
+             images=[images],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
+         )
+
+         return images, has_nsfw_concepts
+
+ # Function
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt):
+     global loaded
+     print(prompt, ckpt)
+
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+
+     if loaded != num_inference_steps:
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
+         pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+         loaded = num_inference_steps
+
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             return Image.new("RGB", (512, 512))
+         return images[0]
+     return results.images[0]
+
+
+
+ # Gradio Interface
+ description = """
+ This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
+ As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             submit = gr.Button(scale=1, variant='primary')
+     img = gr.Image(label='SDXL-Lightning Generated Image')
+
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )
+
+ demo.queue().launch()
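
For reference, the checkpoint-swapping pattern used in generate_image above can also be exercised outside the Space. A minimal sketch, assuming a CUDA GPU and the same model IDs as in app.py; the prompt and output filename are placeholders, not part of this commit:

    import torch
    from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    base = "stabilityai/stable-diffusion-xl-base-1.0"   # same SDXL base as app.py
    repo = "ByteDance/SDXL-Lightning"                   # Lightning UNet weights
    ckpt = "sdxl_lightning_4step_unet.safetensors"      # "4-Step" variant from the table above

    # Load the SDXL base pipeline in fp16 on the GPU
    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")

    # SDXL-Lightning is distilled for trailing timestep spacing and no classifier-free guidance
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

    # Placeholder prompt; guidance_scale=0 and 4 steps match the "4-Step" setting in the demo
    image = pipe("a cat wearing a spacesuit", num_inference_steps=4, guidance_scale=0).images[0]
    image.save("output.png")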