ameerazam08 committed
Commit 5f6eaf4 · verified · 1 Parent(s): 102a2b2
Files changed (1)
  1. app.py +321 -0
app.py ADDED
@@ -0,0 +1,321 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+ import random
+
+ import gradio as gr
+ import numpy as np
+ import PIL.Image
+ import spaces
+ import torch
+ from diffusers import AutoencoderKL, DiffusionPipeline
+
+ DESCRIPTION = """
+ # TempestV0.1
+
+ **Demo by Ameer Azam - [Twitter](https://twitter.com/Ameerazam18) - [GitHub](https://github.com/AMEERAZAM08) - [Hugging Face](https://huggingface.co/ameerazam08)**
+
+ This is a demo of <a href="https://huggingface.co/dataautogpt3/TempestV0.1">TempestV0.1</a> by @dataautogpt3.
+
+ **The code for this demo is based on [@hysts's SD-XL demo](https://huggingface.co/spaces/hysts/SD-XL) running on an A10G GPU.**
+
+ **NOTE: The model is licensed under a non-commercial license.**
+
+ """
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<h1>Running on CPU 🥶 This demo does not work on CPU.</h1>"
+
+ MAX_SEED = np.iinfo(np.int32).max
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
+ ENABLE_REFINER = os.getenv("ENABLE_REFINER", "0") == "1"
+
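+ # Note (added for clarity, not in the original commit): the toggles above are
+ # read once at import time. A usage sketch, assuming a CUDA-capable machine:
+ #     ENABLE_REFINER=1 USE_TORCH_COMPILE=1 python app.py
+ # exposes the refiner controls in the UI and wraps the UNet with torch.compile.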
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ if torch.cuda.is_available():
+     vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+     pipe = DiffusionPipeline.from_pretrained(
+         "dataautogpt3/TempestV0.1",
+         vae=vae,
+         torch_dtype=torch.float16,
+         # variant="fp16",
+     )
+     if ENABLE_REFINER:
+         refiner = DiffusionPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-refiner-1.0",
+             vae=vae,
+             torch_dtype=torch.float16,
+             # variant="fp16",
+         )
+
+     if ENABLE_CPU_OFFLOAD:
+         pipe.enable_model_cpu_offload()
+         if ENABLE_REFINER:
+             refiner.enable_model_cpu_offload()
+     else:
+         pipe.to(device)
+         if ENABLE_REFINER:
+             refiner.to(device)
+
+     if USE_TORCH_COMPILE:
+         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+         if ENABLE_REFINER:
+             refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+
+
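+ # Note (added): the madebyollin/sdxl-vae-fp16-fix VAE loaded above replaces the
+ # stock SDXL VAE, which is known to overflow to NaNs under float16; the fix is
+ # a drop-in substitute that stays numerically stable at half precision.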
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
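+ # Note (added): MAX_SEED caps seeds at the int32 maximum, the largest value the
+ # seed slider accepts. generate() below deliberately builds a CPU
+ # torch.Generator; diffusers draws the initial latents with it before moving
+ # them to the GPU, which keeps a given seed reproducible across devices.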
+ @spaces.GPU(enable_queue=True)
+ def generate(
+     prompt: str,
+     negative_prompt: str = "",
+     prompt_2: str = "",
+     negative_prompt_2: str = "",
+     use_negative_prompt: bool = False,
+     use_prompt_2: bool = False,
+     use_negative_prompt_2: bool = False,
+     seed: int = 0,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale_base: float = 5.0,
+     guidance_scale_refiner: float = 5.0,
+     num_inference_steps_base: int = 25,
+     num_inference_steps_refiner: int = 25,
+     apply_refiner: bool = False,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> PIL.Image.Image:
+     print(f"** Generating image for: \"{prompt}\" **")
+     generator = torch.Generator().manual_seed(seed)
+
+     if not use_negative_prompt:
+         negative_prompt = None  # type: ignore
+     if not use_prompt_2:
+         prompt_2 = None  # type: ignore
+     if not use_negative_prompt_2:
+         negative_prompt_2 = None  # type: ignore
+
+     if not apply_refiner:
+         return pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale_base,
+             num_inference_steps=num_inference_steps_base,
+             generator=generator,
+             output_type="pil",
+         ).images[0]
+     else:
+         latents = pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             width=width,
+             height=height,
+             guidance_scale=guidance_scale_base,
+             num_inference_steps=num_inference_steps_base,
+             generator=generator,
+             output_type="latent",
+         ).images
+         image = refiner(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_2=prompt_2,
+             negative_prompt_2=negative_prompt_2,
+             guidance_scale=guidance_scale_refiner,
+             num_inference_steps=num_inference_steps_refiner,
+             image=latents,
+             generator=generator,
+         ).images[0]
+         return image
+
+
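+ # Usage sketch (added; not part of the original app): generate() can also be
+ # called directly once the module is imported on a GPU machine, e.g.:
+ #
+ #     image = generate("An astronaut riding a green horse", seed=42)
+ #     image.save("astronaut.png")  # returns a PIL.Image.Image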
+ examples = [
+     "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
+     "An astronaut riding a green horse",
+ ]
+
+ theme = gr.themes.Base(
+     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+ )
+ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+     with gr.Group():
+         prompt = gr.Text(
+             label="Prompt",
+             show_label=False,
+             max_lines=1,
+             container=False,
+             placeholder="Enter your prompt",
+         )
+         run_button = gr.Button("Generate")
+         result = gr.Image(label="Result", show_label=False)
+     with gr.Accordion("Advanced options", open=False):
+         with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+             use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
+             use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
+         negative_prompt = gr.Text(
+             label="Negative prompt",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+         prompt_2 = gr.Text(
+             label="Prompt 2",
+             max_lines=1,
+             placeholder="Enter your prompt",
+             visible=False,
+         )
+         negative_prompt_2 = gr.Text(
+             label="Negative prompt 2",
+             max_lines=1,
+             placeholder="Enter a negative prompt",
+             visible=False,
+         )
+
+         seed = gr.Slider(
+             label="Seed",
+             minimum=0,
+             maximum=MAX_SEED,
+             step=1,
+             value=0,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         with gr.Row():
+             width = gr.Slider(
+                 label="Width",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="Height",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+         apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
+         with gr.Row():
+             guidance_scale_base = gr.Slider(
+                 label="Guidance scale for base",
+                 minimum=1,
+                 maximum=20,
+                 step=0.1,
+                 value=5.0,
+             )
+             num_inference_steps_base = gr.Slider(
+                 label="Number of inference steps for base",
+                 minimum=10,
+                 maximum=100,
+                 step=1,
+                 value=25,
+             )
+         with gr.Row(visible=False) as refiner_params:
+             guidance_scale_refiner = gr.Slider(
+                 label="Guidance scale for refiner",
+                 minimum=1,
+                 maximum=20,
+                 step=0.1,
+                 value=5.0,
+             )
+             num_inference_steps_refiner = gr.Slider(
+                 label="Number of inference steps for refiner",
+                 minimum=10,
+                 maximum=100,
+                 step=1,
+                 value=25,
+             )
+
+     gr.Examples(
+         examples=examples,
+         inputs=prompt,
+         outputs=result,
+         fn=generate,
+         cache_examples=CACHE_EXAMPLES,
+     )
+
+     use_negative_prompt.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt,
+         outputs=negative_prompt,
+         queue=False,
+         api_name=False,
+     )
+     use_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_prompt_2,
+         outputs=prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     use_negative_prompt_2.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt_2,
+         outputs=negative_prompt_2,
+         queue=False,
+         api_name=False,
+     )
+     apply_refiner.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=apply_refiner,
+         outputs=refiner_params,
+         queue=False,
+         api_name=False,
+     )
+
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             prompt_2.submit,
+             negative_prompt_2.submit,
+             run_button.click,
+         ],
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             prompt_2,
+             negative_prompt_2,
+             use_negative_prompt,
+             use_prompt_2,
+             use_negative_prompt_2,
+             seed,
+             width,
+             height,
+             guidance_scale_base,
+             guidance_scale_refiner,
+             num_inference_steps_base,
+             num_inference_steps_refiner,
+             apply_refiner,
+         ],
+         outputs=result,
+         api_name="run",
+     )
+
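+ # Note (added): the gr.on(...).then(...) chain above runs randomize_seed_fn
+ # first (immediately, since queue=False) and only then enqueues generate(), so
+ # the seed displayed in the UI is the one generate() actually receives.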
+ if __name__ == "__main__":
+     demo.queue(max_size=20, api_open=False).launch(show_api=False)