Squaad AI committed on
Commit
c9d2694
·
verified ·
1 Parent(s): 87a518e

Create app.py

Files changed (1)
  1. app.py +288 -0
app.py ADDED
@@ -0,0 +1,288 @@
#!/usr/bin/env python

from __future__ import annotations

import os
import random

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

DESCRIPTION = "# Run any LoRA or SD Model"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1824"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
ENABLE_USE_LORA = os.getenv("ENABLE_USE_LORA", "1") == "1"
ENABLE_USE_VAE = os.getenv("ENABLE_USE_VAE", "1") == "1"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    prompt_2: str = "",
    negative_prompt_2: str = "",
    use_negative_prompt: bool = False,
    use_prompt_2: bool = False,
    use_negative_prompt_2: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale_base: float = 5.0,
    num_inference_steps_base: int = 25,
    use_vae: bool = False,
    use_lora: bool = False,
    model: str = "stabilityai/stable-diffusion-xl-base-1.0",
    vaecall: str = "madebyollin/sdxl-vae-fp16-fix",
    lora: str = "",
    lora_scale: float = 0.7,
) -> PIL.Image.Image:
    if not torch.cuda.is_available():
        raise gr.Error("This demo does not work on CPU.")

    # Build the pipeline per request so arbitrary models, VAEs, and LoRAs can be loaded.
    if use_vae:
        vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
        pipe = DiffusionPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)
    else:
        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)

    if use_lora:
        pipe.load_lora_weights(lora)
        # Pass the scale by keyword so it is not mistaken for fuse_lora()'s
        # first positional parameter, which is not the scale.
        pipe.fuse_lora(lora_scale=lora_scale)

    if ENABLE_CPU_OFFLOAD:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)

    if USE_TORCH_COMPILE:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    generator = torch.Generator().manual_seed(seed)

    # Optional prompts fall back to None so the pipeline ignores them.
    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    if not use_prompt_2:
        prompt_2 = None  # type: ignore
    if not use_negative_prompt_2:
        negative_prompt_2 = None  # type: ignore

    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        prompt_2=prompt_2,
        negative_prompt_2=negative_prompt_2,
        width=width,
        height=height,
        guidance_scale=guidance_scale_base,
        num_inference_steps=num_inference_steps_base,
        generator=generator,
        output_type="pil",
    ).images[0]

examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]

with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        model = gr.Text(label="Model", placeholder="e.g. stabilityai/stable-diffusion-xl-base-1.0")
        vaecall = gr.Text(label="VAE", placeholder="e.g. madebyollin/sdxl-vae-fp16-fix")
        lora = gr.Text(label="LoRA", placeholder="e.g. nerijs/pixel-art-xl")
        lora_scale = gr.Slider(
            info="The closer to 1, the more strongly the output follows the LoRA, but artifacts may appear.",
            label="LoRA scale",
            minimum=0.01,
            maximum=1,
            step=0.01,
            value=0.7,
        )
        with gr.Row():
            prompt = gr.Text(
                placeholder="Input prompt",
                label="Prompt",
                show_label=False,
                max_lines=1,
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            use_vae = gr.Checkbox(label="Use VAE", value=False, visible=ENABLE_USE_VAE)
            use_lora = gr.Checkbox(label="Use LoRA", value=False, visible=ENABLE_USE_LORA)
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
            use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
            use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
        negative_prompt = gr.Text(
            placeholder="Input Negative Prompt",
            label="Negative prompt",
            max_lines=1,
            visible=False,
        )
        prompt_2 = gr.Text(
            placeholder="Input Prompt 2",
            label="Prompt 2",
            max_lines=1,
            visible=False,
        )
        negative_prompt_2 = gr.Text(
            placeholder="Input Negative Prompt 2",
            label="Negative prompt 2",
            max_lines=1,
            visible=False,
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )

        with gr.Row():
            guidance_scale_base = gr.Slider(
                info="Scale for classifier-free guidance",
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_base = gr.Slider(
                info="Number of denoising steps",
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Toggle visibility of each optional input alongside its checkbox.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        queue=False,
        api_name=False,
    )
    use_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_prompt_2,
        outputs=prompt_2,
        queue=False,
        api_name=False,
    )
    use_negative_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt_2,
        outputs=negative_prompt_2,
        queue=False,
        api_name=False,
    )
    use_vae.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_vae,
        outputs=vaecall,
        queue=False,
        api_name=False,
    )
    use_lora.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_lora,
        outputs=lora,
        queue=False,
        api_name=False,
    )

    # Resolve the seed first, then run generation with the full input set.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            prompt_2.submit,
            negative_prompt_2.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            prompt_2,
            negative_prompt_2,
            use_negative_prompt,
            use_prompt_2,
            use_negative_prompt_2,
            seed,
            width,
            height,
            guidance_scale_base,
            num_inference_steps_base,
            use_vae,
            use_lora,
            model,
            vaecall,
            lora,
            lora_scale,
        ],
        outputs=result,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=100).launch()
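
Because the final event chain is exposed with api_name="run", the app can also be driven programmatically. Below is a minimal sketch using the gradio_client package; "user/space" is a placeholder for wherever a copy of this app is actually deployed, and the positional arguments mirror the inputs list wired into gr.on(...).then(...) above.

from gradio_client import Client

client = Client("user/space")  # placeholder: replace with the real Space id
image_path = client.predict(
    "An astronaut riding a green horse",  # prompt
    "",     # negative_prompt
    "",     # prompt_2
    "",     # negative_prompt_2
    False,  # use_negative_prompt
    False,  # use_prompt_2
    False,  # use_negative_prompt_2
    0,      # seed
    1024,   # width
    1024,   # height
    5.0,    # guidance_scale_base
    25,     # num_inference_steps_base
    False,  # use_vae
    False,  # use_lora
    "stabilityai/stable-diffusion-xl-base-1.0",  # model
    "",     # vaecall (ignored while use_vae is False)
    "",     # lora (ignored while use_lora is False)
    0.7,    # lora_scale
    api_name="/run",
)
print(image_path)  # gradio_client returns a local path to the generated image

Note that the seed-randomizing step is registered with api_name=False, so it is not part of the public API: a seed passed to /run is used as-is, regardless of the "Randomize seed" default in the UI.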