awacke1 commited on
Commit
cabb8e7
·
verified ·
1 Parent(s): e2f3017

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -0
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+ import base64
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+ import spaces
11
+ import torch
12
+ import glob
13
+ from datetime import datetime
14
+ import json
15
+
16
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
17
+
18
+ DESCRIPTION = """# DALL•E 3 XL v2 High Fi"""
19
+
20
+ VOTE_FILE = "vote_counts.json"
21
+
22
+ def load_vote_counts():
23
+ if os.path.exists(VOTE_FILE):
24
+ with open(VOTE_FILE, "r") as f:
25
+ return json.load(f)
26
+ return {}
27
+
28
+ def save_vote_counts(vote_counts):
29
+ with open(VOTE_FILE, "w") as f:
30
+ json.dump(vote_counts, f)
31
+
32
+ vote_counts = load_vote_counts()
33
+
34
+ def create_download_link(filename):
35
+ with open(filename, "rb") as file:
36
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
37
+ download_link = f'<a href="data:image/png;base64,{encoded_string}" download="{filename}">Download Image</a>'
38
+ return download_link
39
+
40
+ def save_image(img, prompt):
41
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
42
+ filename = f"{timestamp}_{prompt[:50]}.png" # Limit filename length
43
+ img.save(filename)
44
+ return filename
45
+
46
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
47
+ if randomize_seed:
48
+ seed = random.randint(0, MAX_SEED)
49
+ return seed
50
+
51
+ def get_image_gallery():
52
+ image_files = glob.glob("*.png")
53
+ image_files.sort(key=lambda x: calculate_score(x), reverse=True)
54
+ return [(file, f"{file}\n👍 {vote_counts.get(file, {}).get('likes', 0)} 👎 {vote_counts.get(file, {}).get('dislikes', 0)} ❤️ {vote_counts.get(file, {}).get('hearts', 0)}") for file in image_files]
55
+
56
+ def calculate_score(filename):
57
+ counts = vote_counts.get(filename, {})
58
+ return (counts.get('hearts', 0) * 5) + counts.get('likes', 0) - counts.get('dislikes', 0)
59
+
60
+ def delete_all_images():
61
+ for file in glob.glob("*.png"):
62
+ os.remove(file)
63
+ vote_counts.clear()
64
+ save_vote_counts(vote_counts)
65
+ return get_image_gallery()
66
+
67
+ def vote(filename, vote_type):
68
+ if filename not in vote_counts:
69
+ vote_counts[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
70
+ vote_counts[filename][vote_type] += 1
71
+ save_vote_counts(vote_counts)
72
+ return get_image_gallery()
73
+
74
+ MAX_SEED = np.iinfo(np.int32).max
75
+
76
+ if not torch.cuda.is_available():
77
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
78
+
79
+ USE_TORCH_COMPILE = 0
80
+ ENABLE_CPU_OFFLOAD = 0
81
+
82
+ if torch.cuda.is_available():
83
+ pipe = StableDiffusionXLPipeline.from_pretrained(
84
+ "fluently/Fluently-XL-v4",
85
+ torch_dtype=torch.float16,
86
+ use_safetensors=True,
87
+ )
88
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
89
+
90
+ pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
91
+ pipe.set_adapters("dalle")
92
+
93
+ pipe.to("cuda")
94
+
95
+ @spaces.GPU(enable_queue=True)
96
+ def generate(
97
+ prompt: str,
98
+ negative_prompt: str = "",
99
+ use_negative_prompt: bool = False,
100
+ seed: int = 0,
101
+ width: int = 1024,
102
+ height: int = 1024,
103
+ guidance_scale: float = 3,
104
+ randomize_seed: bool = False,
105
+ progress=gr.Progress(track_tqdm=True),
106
+ ):
107
+ seed = int(randomize_seed_fn(seed, randomize_seed))
108
+
109
+ if not use_negative_prompt:
110
+ negative_prompt = ""
111
+
112
+ images = pipe(
113
+ prompt=prompt,
114
+ negative_prompt=negative_prompt,
115
+ width=width,
116
+ height=height,
117
+ guidance_scale=guidance_scale,
118
+ num_inference_steps=20,
119
+ num_images_per_prompt=1,
120
+ cross_attention_kwargs={"scale": 0.65},
121
+ output_type="pil",
122
+ ).images
123
+ image_paths = [save_image(img, prompt) for img in images]
124
+ download_links = [create_download_link(path) for path in image_paths]
125
+
126
+ return image_paths, seed, download_links, get_image_gallery()
127
+
128
+ examples = [
129
+ "A bustling cityscape at sunset, where sleek, translucent skyscrapers shimmer with holographic displays. Gentle AI-powered drones weave through the air, delivering packages and tending to floating gardens. In the foreground, a diverse group of humans and lifelike androids share a laugh at a streetside café, their animated conversation bringing warmth to the scene.",
130
+ # ... (other examples remain unchanged)
131
+ ]
132
+
133
+ css = '''
134
+ .gradio-container{max-width: 1024px !important}
135
+ h1{text-align:center}
136
+ footer {
137
+ visibility: hidden
138
+ }
139
+ '''
140
+
141
+ with gr.Blocks(css=css, theme="pseudolab/huggingface-korea-theme") as demo:
142
+ gr.Markdown(DESCRIPTION)
143
+
144
+ with gr.Group():
145
+ with gr.Row():
146
+ prompt = gr.Text(
147
+ label="Prompt",
148
+ show_label=False,
149
+ max_lines=1,
150
+ placeholder="Enter your prompt",
151
+ container=False,
152
+ )
153
+ run_button = gr.Button("Run", scale=0)
154
+ result = gr.Gallery(label="Result", columns=1, preview=True, show_label=False)
155
+ with gr.Accordion("Advanced options", open=False):
156
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
157
+ negative_prompt = gr.Text(
158
+ label="Negative prompt",
159
+ lines=4,
160
+ max_lines=6,
161
+ value="""(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, (NSFW:1.25)""",
162
+ placeholder="Enter a negative prompt",
163
+ visible=True,
164
+ )
165
+ seed = gr.Slider(
166
+ label="Seed",
167
+ minimum=0,
168
+ maximum=MAX_SEED,
169
+ step=1,
170
+ value=0,
171
+ visible=True
172
+ )
173
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
174
+ with gr.Row(visible=True):
175
+ width = gr.Slider(
176
+ label="Width",
177
+ minimum=512,
178
+ maximum=2048,
179
+ step=8,
180
+ value=1920,
181
+ )
182
+ height = gr.Slider(
183
+ label="Height",
184
+ minimum=512,
185
+ maximum=2048,
186
+ step=8,
187
+ value=1080,
188
+ )
189
+ with gr.Row():
190
+ guidance_scale = gr.Slider(
191
+ label="Guidance Scale",
192
+ minimum=0.1,
193
+ maximum=20.0,
194
+ step=0.1,
195
+ value=20.0,
196
+ )
197
+
198
+ image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
199
+
200
+ with gr.Row():
201
+ delete_all_button = gr.Button("🗑️ Delete All Images")
202
+ like_button = gr.Button("👍 Like")
203
+ dislike_button = gr.Button("👎 Dislike")
204
+ heart_button = gr.Button("❤️ Heart")
205
+
206
+ selected_image = gr.State(None)
207
+
208
+ gr.Examples(
209
+ examples=examples,
210
+ inputs=prompt,
211
+ outputs=[result, seed],
212
+ fn=generate,
213
+ cache_examples=False,
214
+ )
215
+
216
+ use_negative_prompt.change(
217
+ fn=lambda x: gr.update(visible=x),
218
+ inputs=use_negative_prompt,
219
+ outputs=negative_prompt,
220
+ api_name=False,
221
+ )
222
+
223
+ delete_all_button.click(
224
+ fn=delete_all_images,
225
+ inputs=[],
226
+ outputs=[image_gallery],
227
+ )
228
+
229
+ image_gallery.select(
230
+ fn=lambda evt: evt,
231
+ inputs=[gr.State("value")],
232
+ outputs=[selected_image],
233
+ )
234
+
235
+ like_button.click(
236
+ fn=lambda x: vote(x, 'likes') if x else None,
237
+ inputs=[selected_image],
238
+ outputs=[image_gallery],
239
+ )
240
+
241
+ dislike_button.click(
242
+ fn=lambda x: vote(x, 'dislikes') if x else None,
243
+ inputs=[selected_image],
244
+ outputs=[image_gallery],
245
+ )
246
+
247
+ heart_button.click(
248
+ fn=lambda x: vote(x, 'hearts') if x else None,
249
+ inputs=[selected_image],
250
+ outputs=[image_gallery],
251
+ )
252
+
253
+ def update_gallery():
254
+ return gr.update(value=get_image_gallery())
255
+
256
+ gr.on(
257
+ triggers=[
258
+ prompt.submit,
259
+ negative_prompt.submit,
260
+ run_button.click,
261
+ ],
262
+ fn=generate,
263
+ inputs=[
264
+ prompt,
265
+ negative_prompt,
266
+ use_negative_prompt,
267
+ seed,
268
+ width,
269
+ height,
270
+ guidance_scale,
271
+ randomize_seed,
272
+ ],
273
+ outputs=[result, seed, gr.HTML(visible=False), image_gallery],
274
+ api_name="run",
275
+ )
276
+
277
+ demo.load(fn=update_gallery, outputs=image_gallery)
278
+
279
+ if __name__ == "__main__":
280
+ demo.queue(max_size=20).launch(show_api=False, debug=False)