awacke1 commited on
Commit
0f062e5
·
verified ·
1 Parent(s): 97e016b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
4
+ from diffusers.schedulers import TCDScheduler
5
+ import spaces
6
+ from PIL import Image
7
+ import os
8
+ import re
9
+ from datetime import datetime
10
+ import random
11
+ import glob
12
+
13
+ SAFETY_CHECKER = True
14
+
15
+ checkpoints = {
16
+ "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
17
+ "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
18
+ "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
19
+ "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
20
+ "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
21
+ "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
22
+ "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
23
+ "LCM-Like LoRA": ["pcm_{}_lcmlike_lora_converted.safetensors", 4, 0.0],
24
+ }
25
+
26
+ loaded = None
27
+
28
+ if torch.cuda.is_available():
29
+ pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
30
+ "stabilityai/stable-diffusion-xl-base-1.0",
31
+ torch_dtype=torch.float16,
32
+ variant="fp16",
33
+ ).to("cuda")
34
+ pipe_sd15 = StableDiffusionPipeline.from_pretrained(
35
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
36
+ ).to("cuda")
37
+
38
+ if SAFETY_CHECKER:
39
+ from safety_checker import StableDiffusionSafetyChecker
40
+ from transformers import CLIPFeatureExtractor
41
+
42
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
43
+ "CompVis/stable-diffusion-safety-checker"
44
+ ).to("cuda")
45
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
46
+ "openai/clip-vit-base-patch32"
47
+ )
48
+
49
+ def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
50
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
51
+ has_nsfw_concepts = safety_checker(
52
+ images=[images], clip_input=safety_checker_input.pixel_values.to("cuda")
53
+ )
54
+ return images, has_nsfw_concepts
55
+
56
+ def save_image(image: Image.Image, prompt: str) -> str:
57
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
+ clean_prompt = re.sub(r'[^\w\-_\. ]', '_', prompt)[:50]
59
+ filename = f"{timestamp}_{clean_prompt}.png"
60
+ image.save(filename)
61
+ return filename
62
+
63
+ def get_image_gallery():
64
+ image_files = glob.glob("*.png")
65
+ return sorted([(file, file) for file in image_files], key=lambda x: os.path.getmtime(x[0]), reverse=True)
66
+
67
+ @spaces.GPU(enable_queue=True)
68
+ def generate_image(
69
+ prompt,
70
+ ckpt,
71
+ num_inference_steps,
72
+ progress=gr.Progress(track_tqdm=True),
73
+ mode="sdxl",
74
+ ):
75
+ global loaded
76
+ checkpoint = checkpoints[ckpt][0].format(mode)
77
+ guidance_scale = checkpoints[ckpt][2]
78
+ pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15
79
+
80
+ if loaded != (ckpt + mode):
81
+ pipe.load_lora_weights(
82
+ "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder=mode
83
+ )
84
+ loaded = ckpt + mode
85
+
86
+ if ckpt == "LCM-Like LoRA":
87
+ pipe.scheduler = LCMScheduler()
88
+ else:
89
+ pipe.scheduler = TCDScheduler(
90
+ num_train_timesteps=1000,
91
+ beta_start=0.00085,
92
+ beta_end=0.012,
93
+ beta_schedule="scaled_linear",
94
+ timestep_spacing="trailing",
95
+ )
96
+
97
+ results = pipe(
98
+ prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
99
+ )
100
+
101
+ if SAFETY_CHECKER:
102
+ images, has_nsfw_concepts = check_nsfw_images(results.images)
103
+ if any(has_nsfw_concepts):
104
+ gr.Warning("NSFW content detected.")
105
+ return Image.new("RGB", (512, 512)), get_image_gallery()
106
+ filename = save_image(images[0], prompt)
107
+ return images[0], get_image_gallery()
108
+ filename = save_image(results.images[0], prompt)
109
+ return results.images[0], get_image_gallery()
110
+
111
+ def update_steps(ckpt):
112
+ num_inference_steps = checkpoints[ckpt][1]
113
+ if ckpt == "LCM-Like LoRA":
114
+ return gr.update(interactive=True, value=num_inference_steps)
115
+ return gr.update(interactive=False, value=num_inference_steps)
116
+
117
+ css = """
118
+ .gradio-container {
119
+ max-width: 60rem !important;
120
+ }
121
+ """
122
+
123
+ art_styles = ['Impressionist', 'Cubist', 'Surrealist', 'Abstract Expressionist', 'Pop Art', 'Minimalist', 'Baroque', 'Art Nouveau', 'Pointillist', 'Fauvism']
124
+
125
+ examples = [
126
+ f"{random.choice(art_styles)} painting of a majestic lighthouse on a rocky coast. Use bold brushstrokes and a vibrant color palette to capture the interplay of light and shadow as the lighthouse beam cuts through a stormy night sky.",
127
+ f"{random.choice(art_styles)} still life featuring a pair of vintage eyeglasses. Focus on the intricate details of the frames and lenses, using a warm color scheme to evoke a sense of nostalgia and wisdom.",
128
+ f"{random.choice(art_styles)} depiction of a rustic wooden stool in a sunlit artist's studio. Emphasize the texture of the wood and the interplay of light and shadow, using a mix of earthy tones and highlights.",
129
+ f"{random.choice(art_styles)} scene viewed through an ornate window frame. Contrast the intricate details of the window with a dreamy, soft-focus landscape beyond, using a palette that transitions from cool interior tones to warm exterior hues.",
130
+ f"{random.choice(art_styles)} close-up study of interlaced fingers. Use a monochromatic color scheme to emphasize the form and texture of the hands, with dramatic lighting to create depth and emotion.",
131
+ f"{random.choice(art_styles)} composition featuring a set of dice in motion. Capture the energy and randomness of the throw, using a dynamic color palette and blurred lines to convey movement.",
132
+ f"{random.choice(art_styles)} interpretation of heaven. Create an ethereal atmosphere with soft, billowing clouds and radiant light, using a palette of celestial blues, golds, and whites.",
133
+ f"{random.choice(art_styles)} portrayal of an ancient, mystical gate. Combine architectural details with elements of fantasy, using a rich, jewel-toned palette to create an air of mystery and magic.",
134
+ f"{random.choice(art_styles)} portrait of a curious cat. Focus on capturing the feline's expressive eyes and sleek form, using a mix of bold and subtle colors to bring out the cat's personality.",
135
+ f"{random.choice(art_styles)} abstract representation of toes in sand. Use textured brushstrokes to convey the feeling of warm sand, with a palette inspired by a sun-drenched beach."
136
+ ]
137
+
138
+ with gr.Blocks(css=css) as demo:
139
+ gr.Markdown(
140
+ """
141
+ # Phased Consistency Model
142
+
143
+ Phased Consistency Model (PCM) is an image generation technique that addresses the limitations of the Latent Consistency Model (LCM) in high-resolution and text-conditioned image generation.
144
+ PCM outperforms LCM across various generation settings and achieves state-of-the-art results in both image and video generation.
145
+
146
+ [[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
147
+ """
148
+ )
149
+ with gr.Group():
150
+ with gr.Row():
151
+ prompt = gr.Textbox(label="Prompt", scale=8)
152
+ ckpt = gr.Dropdown(
153
+ label="Select inference steps",
154
+ choices=list(checkpoints.keys()),
155
+ value="4-Step",
156
+ )
157
+ steps = gr.Slider(
158
+ label="Number of Inference Steps",
159
+ minimum=1,
160
+ maximum=20,
161
+ step=1,
162
+ value=4,
163
+ interactive=False,
164
+ )
165
+ ckpt.change(
166
+ fn=update_steps,
167
+ inputs=[ckpt],
168
+ outputs=[steps],
169
+ queue=False,
170
+ show_progress=False,
171
+ )
172
+
173
+ submit_sdxl = gr.Button("Run on SDXL", scale=1)
174
+ submit_sd15 = gr.Button("Run on SD15", scale=1)
175
+
176
+ img = gr.Image(label="PCM Image")
177
+ gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
178
+ gr.Examples(
179
+ examples=examples,
180
+ inputs=[prompt, ckpt, steps],
181
+ outputs=[img, gallery],
182
+ fn=generate_image,
183
+ cache_examples=True,
184
+ )
185
+
186
+ gr.on(
187
+ fn=generate_image,
188
+ triggers=[ckpt.change, prompt.submit, submit_sdxl.click],
189
+ inputs=[prompt, ckpt, steps],
190
+ outputs=[img, gallery],
191
+ )
192
+ gr.on(
193
+ fn=lambda *args: generate_image(*args, mode="sd15"),
194
+ triggers=[submit_sd15.click],
195
+ inputs=[prompt, ckpt, steps],
196
+ outputs=[img, gallery],
197
+ )
198
+
199
+ demo.load(fn=get_image_gallery, outputs=gallery)
200
+
201
+ demo.queue(api_open=False).launch(show_api=False)