TastyRice commited on
Commit
5ab0963
1 Parent(s): 267d317

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +395 -0
app.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ import json
7
+ import spaces
8
+ import config
9
+ import utils
10
+ import logging
11
+ from PIL import Image, PngImagePlugin
12
+ from datetime import datetime
13
+ from diffusers.models import AutoencoderKL
14
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ DESCRIPTION = "Magic_on_paper"
20
+ if not torch.cuda.is_available():
21
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
22
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
25
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
26
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
27
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
28
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
29
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
30
+
31
+ MODEL = os.getenv(
32
+ "MODEL",
33
+ "https://huggingface.co/Tasty-Rice/Magic_on_paper/blob/main/Magic_on_paper-SDXL-v3.safetensors",
34
+ )
35
+
36
+ torch.backends.cudnn.deterministic = True
37
+ torch.backends.cudnn.benchmark = False
38
+
39
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
+
41
+
42
+ def load_pipeline(model_name):
43
+ vae = AutoencoderKL.from_pretrained(
44
+ "madebyollin/sdxl-vae-fp16-fix",
45
+ torch_dtype=torch.float16,
46
+ )
47
+ pipeline = (
48
+ StableDiffusionXLPipeline.from_single_file
49
+ if MODEL.endswith(".safetensors")
50
+ else StableDiffusionXLPipeline.from_pretrained
51
+ )
52
+
53
+ pipe = pipeline(
54
+ model_name,
55
+ vae=vae,
56
+ torch_dtype=torch.float16,
57
+ custom_pipeline="lpw_stable_diffusion_xl",
58
+ use_safetensors=True,
59
+ add_watermarker=False,
60
+ use_auth_token=HF_TOKEN,
61
+ )
62
+
63
+ pipe.to(device)
64
+ return pipe
65
+
66
+
67
+ @spaces.GPU
68
+ def generate(
69
+ prompt: str,
70
+ negative_prompt: str = "",
71
+ seed: int = 0,
72
+ custom_width: int = 1024,
73
+ custom_height: int = 1024,
74
+ guidance_scale: float = 7.0,
75
+ num_inference_steps: int = 28,
76
+ sampler: str = "Euler a",
77
+ aspect_ratio_selector: str = "768 x 1344",
78
+ style_selector: str = "(None)",
79
+ quality_selector: str = "Standard v3.1",
80
+ use_upscaler: bool = False,
81
+ upscaler_strength: float = 0.55,
82
+ upscale_by: float = 1.5,
83
+ add_quality_tags: bool = True,
84
+ progress=gr.Progress(track_tqdm=True),
85
+ ):
86
+ generator = utils.seed_everything(seed)
87
+
88
+ width, height = utils.aspect_ratio_handler(
89
+ aspect_ratio_selector,
90
+ custom_width,
91
+ custom_height,
92
+ )
93
+
94
+ prompt = utils.add_character(prompt, character_files)
95
+
96
+ prompt, negative_prompt = utils.preprocess_prompt(
97
+ quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
98
+ )
99
+ prompt, negative_prompt = utils.preprocess_prompt(
100
+ styles, style_selector, prompt, negative_prompt
101
+ )
102
+
103
+ width, height = utils.preprocess_image_dimensions(width, height)
104
+
105
+ backup_scheduler = pipe.scheduler
106
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
107
+
108
+ if use_upscaler:
109
+ upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
110
+ metadata = {
111
+ "prompt": prompt,
112
+ "negative_prompt": negative_prompt,
113
+ "resolution": f"{width} x {height}",
114
+ "guidance_scale": guidance_scale,
115
+ "num_inference_steps": num_inference_steps,
116
+ "seed": seed,
117
+ "sampler": sampler,
118
+ "sdxl_style": style_selector,
119
+ "add_quality_tags": add_quality_tags,
120
+ "quality_tags": quality_selector,
121
+ }
122
+
123
+ if use_upscaler:
124
+ new_width = int(width * upscale_by)
125
+ new_height = int(height * upscale_by)
126
+ metadata["use_upscaler"] = {
127
+ "upscale_method": "nearest-exact",
128
+ "upscaler_strength": upscaler_strength,
129
+ "upscale_by": upscale_by,
130
+ "new_resolution": f"{new_width} x {new_height}",
131
+ }
132
+ else:
133
+ metadata["use_upscaler"] = None
134
+ metadata["Model"] = {
135
+ "Model": DESCRIPTION,
136
+ "Model hash": "e3c47aedb0",
137
+ }
138
+
139
+ logger.info(json.dumps(metadata, indent=4))
140
+
141
+ try:
142
+ if use_upscaler:
143
+ latents = pipe(
144
+ prompt=prompt,
145
+ negative_prompt=negative_prompt,
146
+ width=width,
147
+ height=height,
148
+ guidance_scale=guidance_scale,
149
+ num_inference_steps=num_inference_steps,
150
+ generator=generator,
151
+ output_type="latent",
152
+ ).images
153
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
154
+ images = upscaler_pipe(
155
+ prompt=prompt,
156
+ negative_prompt=negative_prompt,
157
+ image=upscaled_latents,
158
+ guidance_scale=guidance_scale,
159
+ num_inference_steps=num_inference_steps,
160
+ strength=upscaler_strength,
161
+ generator=generator,
162
+ output_type="pil",
163
+ ).images
164
+ else:
165
+ images = pipe(
166
+ prompt=prompt,
167
+ negative_prompt=negative_prompt,
168
+ width=width,
169
+ height=height,
170
+ guidance_scale=guidance_scale,
171
+ num_inference_steps=num_inference_steps,
172
+ generator=generator,
173
+ output_type="pil",
174
+ ).images
175
+
176
+ if images:
177
+ image_paths = [
178
+ utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
179
+ for image in images
180
+ ]
181
+
182
+ for image_path in image_paths:
183
+ logger.info(f"Image saved as {image_path} with metadata")
184
+
185
+ return image_paths, metadata
186
+ except Exception as e:
187
+ logger.exception(f"An error occurred: {e}")
188
+ raise
189
+ finally:
190
+ if use_upscaler:
191
+ del upscaler_pipe
192
+ pipe.scheduler = backup_scheduler
193
+ utils.free_memory()
194
+
195
+
196
+ if torch.cuda.is_available():
197
+ pipe = load_pipeline(MODEL)
198
+ logger.info("Loaded on Device!")
199
+ else:
200
+ pipe = None
201
+
202
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
203
+ quality_prompt = {
204
+ k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.quality_prompt_list
205
+ }
206
+
207
+ character_files = utils.load_character_files("character")
208
+
209
+ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
210
+ title = gr.HTML(
211
+ f"""<h1><span>{DESCRIPTION}</span></h1>""",
212
+ elem_id="title",
213
+ )
214
+ gr.Markdown(
215
+ f"""Gradio demo for [Tasty-Rice/Magic_on_paper](https://huggingface.co/Tasty-Rice/Magic_on_paper)""",
216
+ elem_id="subtitle",
217
+ )
218
+ gr.DuplicateButton(
219
+ value="Duplicate Space for private use",
220
+ elem_id="duplicate-button",
221
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
222
+ )
223
+ with gr.Row():
224
+ with gr.Column(scale=2):
225
+ with gr.Tab("Txt2img"):
226
+ with gr.Group():
227
+ prompt = gr.Text(
228
+ label="Prompt",
229
+ max_lines=5,
230
+ placeholder="Enter your prompt",
231
+ )
232
+ negative_prompt = gr.Text(
233
+ label="Negative Prompt",
234
+ max_lines=5,
235
+ placeholder="Enter a negative prompt",
236
+ )
237
+ with gr.Accordion(label="Quality Tags", open=True):
238
+ add_quality_tags = gr.Checkbox(
239
+ label="Add Quality Tags", value=True
240
+ )
241
+ quality_selector = gr.Dropdown(
242
+ label="Quality Tags Presets",
243
+ interactive=True,
244
+ choices=list(quality_prompt.keys()),
245
+ value="Standard v3.1",
246
+ )
247
+ with gr.Tab("Advanced Settings"):
248
+ with gr.Group():
249
+ style_selector = gr.Radio(
250
+ label="Style Preset",
251
+ container=True,
252
+ interactive=True,
253
+ choices=list(styles.keys()),
254
+ value="(None)",
255
+ )
256
+ with gr.Group():
257
+ aspect_ratio_selector = gr.Radio(
258
+ label="Aspect Ratio",
259
+ choices=config.aspect_ratios,
260
+ value="896 x 1152",
261
+ container=True,
262
+ )
263
+ with gr.Group(visible=False) as custom_resolution:
264
+ with gr.Row():
265
+ custom_width = gr.Slider(
266
+ label="Width",
267
+ minimum=MIN_IMAGE_SIZE,
268
+ maximum=MAX_IMAGE_SIZE,
269
+ step=8,
270
+ value=1024,
271
+ )
272
+ custom_height = gr.Slider(
273
+ label="Height",
274
+ minimum=MIN_IMAGE_SIZE,
275
+ maximum=MAX_IMAGE_SIZE,
276
+ step=8,
277
+ value=1024,
278
+ )
279
+ with gr.Group():
280
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
281
+ with gr.Row() as upscaler_row:
282
+ upscaler_strength = gr.Slider(
283
+ label="Strength",
284
+ minimum=0,
285
+ maximum=1,
286
+ step=0.05,
287
+ value=0.55,
288
+ visible=False,
289
+ )
290
+ upscale_by = gr.Slider(
291
+ label="Upscale by",
292
+ minimum=1,
293
+ maximum=1.5,
294
+ step=0.1,
295
+ value=1.5,
296
+ visible=False,
297
+ )
298
+ with gr.Group():
299
+ sampler = gr.Dropdown(
300
+ label="Sampler",
301
+ choices=config.sampler_list,
302
+ interactive=True,
303
+ value="Euler a",
304
+ )
305
+ with gr.Group():
306
+ seed = gr.Slider(
307
+ label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
308
+ )
309
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
310
+ with gr.Group():
311
+ with gr.Row():
312
+ guidance_scale = gr.Slider(
313
+ label="Guidance scale",
314
+ minimum=1,
315
+ maximum=12,
316
+ step=0.1,
317
+ value=7.0,
318
+ )
319
+ num_inference_steps = gr.Slider(
320
+ label="Number of inference steps",
321
+ minimum=1,
322
+ maximum=50,
323
+ step=1,
324
+ value=28,
325
+ )
326
+ with gr.Column(scale=3):
327
+ with gr.Blocks():
328
+ run_button = gr.Button("Generate", variant="primary")
329
+ result = gr.Gallery(
330
+ label="Result",
331
+ columns=1,
332
+ height='100%',
333
+ preview=True,
334
+ show_label=False
335
+ )
336
+ with gr.Accordion(label="Generation Parameters", open=False):
337
+ gr_metadata = gr.JSON(label="metadata", show_label=False)
338
+ gr.Examples(
339
+ examples=config.examples,
340
+ inputs=prompt,
341
+ outputs=[result, gr_metadata],
342
+ fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
343
+ cache_examples=CACHE_EXAMPLES,
344
+ )
345
+ use_upscaler.change(
346
+ fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
347
+ inputs=use_upscaler,
348
+ outputs=[upscaler_strength, upscale_by],
349
+ queue=False,
350
+ api_name=False,
351
+ )
352
+ aspect_ratio_selector.change(
353
+ fn=lambda x: gr.update(visible=x == "Custom"),
354
+ inputs=aspect_ratio_selector,
355
+ outputs=custom_resolution,
356
+ queue=False,
357
+ api_name=False,
358
+ )
359
+
360
+ gr.on(
361
+ triggers=[
362
+ prompt.submit,
363
+ negative_prompt.submit,
364
+ run_button.click,
365
+ ],
366
+ fn=utils.randomize_seed_fn,
367
+ inputs=[seed, randomize_seed],
368
+ outputs=seed,
369
+ queue=False,
370
+ api_name=False,
371
+ ).then(
372
+ fn=generate,
373
+ inputs=[
374
+ prompt,
375
+ negative_prompt,
376
+ seed,
377
+ custom_width,
378
+ custom_height,
379
+ guidance_scale,
380
+ num_inference_steps,
381
+ sampler,
382
+ aspect_ratio_selector,
383
+ style_selector,
384
+ quality_selector,
385
+ use_upscaler,
386
+ upscaler_strength,
387
+ upscale_by,
388
+ add_quality_tags,
389
+ ],
390
+ outputs=[result, gr_metadata],
391
+ api_name="run",
392
+ )
393
+
394
+ if __name__ == "__main__":
395
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)