awacke1 commited on
Commit
f09fbda
·
verified ·
1 Parent(s): 6ac12c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +320 -0
app.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+ import base64
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+ import spaces
11
+ import torch
12
+ import glob
13
+ from datetime import datetime
14
+ import json
15
+
16
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
17
+
18
+ DESCRIPTION = """# DALL•E 3 XL v2 High Fi"""
19
+
20
+ VOTE_FILE = "vote_counts.json"
21
+ PROMPT_HISTORY_FILE = "prompt_history.json"
22
+
23
+ def load_json_file(filename):
24
+ if os.path.exists(filename):
25
+ with open(filename, "r") as f:
26
+ return json.load(f)
27
+ return {}
28
+
29
+ def save_json_file(data, filename):
30
+ with open(filename, "w") as f:
31
+ json.dump(data, f)
32
+
33
+ vote_counts = load_json_file(VOTE_FILE)
34
+ prompt_history = load_json_file(PROMPT_HISTORY_FILE)
35
+
36
+ def create_download_link(filename):
37
+ with open(filename, "rb") as file:
38
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
39
+ download_link = f'<a href="data:image/png;base64,{encoded_string}" download="{filename}">Download Image</a>'
40
+ return download_link
41
+
42
+ def save_image(img, prompt):
43
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
44
+ filename = f"{timestamp}_{prompt[:50]}.png" # Limit filename length
45
+ img.save(filename)
46
+ prompt_history[filename] = prompt
47
+ save_json_file(prompt_history, PROMPT_HISTORY_FILE)
48
+ return filename
49
+
50
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
51
+ if randomize_seed:
52
+ seed = random.randint(0, MAX_SEED)
53
+ return seed
54
+
55
+ def get_image_gallery():
56
+ image_files = glob.glob("*.png")
57
+ image_files.sort(key=lambda x: calculate_score(x), reverse=True)
58
+ return [(file, f"{file}\n👍 {vote_counts.get(file, {}).get('likes', 0)} 👎 {vote_counts.get(file, {}).get('dislikes', 0)} ❤️ {vote_counts.get(file, {}).get('hearts', 0)}\n{prompt_history.get(file, '')}") for file in image_files]
59
+
60
+ def calculate_score(filename):
61
+ counts = vote_counts.get(filename, {})
62
+ return (counts.get('hearts', 0) * 5) + counts.get('likes', 0) - counts.get('dislikes', 0)
63
+
64
+ def delete_all_images():
65
+ for file in glob.glob("*.png"):
66
+ os.remove(file)
67
+ vote_counts.clear()
68
+ prompt_history.clear()
69
+ save_json_file(vote_counts, VOTE_FILE)
70
+ save_json_file(prompt_history, PROMPT_HISTORY_FILE)
71
+ return get_image_gallery()
72
+
73
+ def delete_image(filename):
74
+ if os.path.exists(filename):
75
+ os.remove(filename)
76
+ if filename in vote_counts:
77
+ del vote_counts[filename]
78
+ if filename in prompt_history:
79
+ del prompt_history[filename]
80
+ save_json_file(vote_counts, VOTE_FILE)
81
+ save_json_file(prompt_history, PROMPT_HISTORY_FILE)
82
+ return get_image_gallery()
83
+
84
+ def vote(filename, vote_type):
85
+ if filename:
86
+ if filename not in vote_counts:
87
+ vote_counts[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
88
+ vote_counts[filename][vote_type] += 1
89
+ save_json_file(vote_counts, VOTE_FILE)
90
+ return get_image_gallery()
91
+
92
+ def get_random_style():
93
+ styles = [
94
+ "Impressionist", "Cubist", "Surrealist", "Abstract Expressionist",
95
+ "Pop Art", "Minimalist", "Baroque", "Art Nouveau", "Pointillist", "Fauvism"
96
+ ]
97
+ return random.choice(styles)
98
+
99
+ MAX_SEED = np.iinfo(np.int32).max
100
+
101
+ if not torch.cuda.is_available():
102
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
103
+
104
+ USE_TORCH_COMPILE = 0
105
+ ENABLE_CPU_OFFLOAD = 0
106
+
107
+ if torch.cuda.is_available():
108
+ pipe = StableDiffusionXLPipeline.from_pretrained(
109
+ "fluently/Fluently-XL-v4",
110
+ torch_dtype=torch.float16,
111
+ use_safetensors=True,
112
+ )
113
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
114
+
115
+ pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
116
+ pipe.set_adapters("dalle")
117
+
118
+ pipe.to("cuda")
119
+
120
+ @spaces.GPU(enable_queue=True)
121
+ def generate(
122
+ prompt: str,
123
+ negative_prompt: str = "",
124
+ use_negative_prompt: bool = False,
125
+ seed: int = 0,
126
+ width: int = 1024,
127
+ height: int = 1024,
128
+ guidance_scale: float = 3,
129
+ randomize_seed: bool = False,
130
+ progress=gr.Progress(track_tqdm=True),
131
+ ):
132
+ seed = int(randomize_seed_fn(seed, randomize_seed))
133
+
134
+ if not use_negative_prompt:
135
+ negative_prompt = ""
136
+
137
+ images = pipe(
138
+ prompt=prompt,
139
+ negative_prompt=negative_prompt,
140
+ width=width,
141
+ height=height,
142
+ guidance_scale=guidance_scale,
143
+ num_inference_steps=20,
144
+ num_images_per_prompt=1,
145
+ cross_attention_kwargs={"scale": 0.65},
146
+ output_type="pil",
147
+ ).images
148
+ image_paths = [save_image(img, prompt) for img in images]
149
+ download_links = [create_download_link(path) for path in image_paths]
150
+
151
+ return image_paths, seed, download_links, get_image_gallery()
152
+
153
+ examples = [
154
+ f"{get_random_style()} painting of a majestic lighthouse on a rocky coast. Use bold brushstrokes and a vibrant color palette to capture the interplay of light and shadow as the lighthouse beam cuts through a stormy night sky.",
155
+ f"{get_random_style()} still life featuring a pair of vintage eyeglasses. Focus on the intricate details of the frames and lenses, using a warm color scheme to evoke a sense of nostalgia and wisdom.",
156
+ f"{get_random_style()} depiction of a rustic wooden stool in a sunlit artist's studio. Emphasize the texture of the wood and the interplay of light and shadow, using a mix of earthy tones and highlights.",
157
+ f"{get_random_style()} scene viewed through an ornate window frame. Contrast the intricate details of the window with a dreamy, soft-focus landscape beyond, using a palette that transitions from cool interior tones to warm exterior hues.",
158
+ f"{get_random_style()} close-up study of interlaced fingers. Use a monochromatic color scheme to emphasize the form and texture of the hands, with dramatic lighting to create depth and emotion.",
159
+ f"{get_random_style()} composition featuring a set of dice in motion. Capture the energy and randomness of the throw, using a dynamic color palette and blurred lines to convey movement.",
160
+ f"{get_random_style()} interpretation of heaven. Create an ethereal atmosphere with soft, billowing clouds and radiant light, using a palette of celestial blues, golds, and whites.",
161
+ f"{get_random_style()} portrayal of an ancient, mystical gate. Combine architectural details with elements of fantasy, using a rich, jewel-toned palette to create an air of mystery and magic.",
162
+ f"{get_random_style()} portrait of a curious cat. Focus on capturing the feline's expressive eyes and sleek form, using a mix of bold and subtle colors to bring out the cat's personality.",
163
+ f"{get_random_style()} abstract representation of toes in sand. Use textured brushstrokes to convey the feeling of warm sand, with a palette inspired by a sun-drenched beach."
164
+ ]
165
+
166
+ css = '''
167
+ .gradio-container{max-width: 1024px !important}
168
+ h1{text-align:center}
169
+ footer {
170
+ visibility: hidden
171
+ }
172
+ '''
173
+
174
+ with gr.Blocks(css=css, theme="pseudolab/huggingface-korea-theme") as demo:
175
+ gr.Markdown(DESCRIPTION)
176
+
177
+ with gr.Group():
178
+ with gr.Row():
179
+ prompt = gr.Text(
180
+ label="Prompt",
181
+ show_label=False,
182
+ max_lines=1,
183
+ placeholder="Enter your prompt",
184
+ container=False,
185
+ )
186
+ run_button = gr.Button("Run", scale=0)
187
+ result = gr.Gallery(label="Result", columns=1, preview=True, show_label=False)
188
+ with gr.Accordion("Advanced options", open=False):
189
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
190
+ negative_prompt = gr.Text(
191
+ label="Negative prompt",
192
+ lines=4,
193
+ max_lines=6,
194
+ value="""(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, (NSFW:1.25)""",
195
+ placeholder="Enter a negative prompt",
196
+ visible=True,
197
+ )
198
+ seed = gr.Slider(
199
+ label="Seed",
200
+ minimum=0,
201
+ maximum=MAX_SEED,
202
+ step=1,
203
+ value=0,
204
+ visible=True
205
+ )
206
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
207
+ with gr.Row(visible=True):
208
+ width = gr.Slider(
209
+ label="Width",
210
+ minimum=512,
211
+ maximum=2048,
212
+ step=8,
213
+ value=1920,
214
+ )
215
+ height = gr.Slider(
216
+ label="Height",
217
+ minimum=512,
218
+ maximum=2048,
219
+ step=8,
220
+ value=1080,
221
+ )
222
+ with gr.Row():
223
+ guidance_scale = gr.Slider(
224
+ label="Guidance Scale",
225
+ minimum=0.1,
226
+ maximum=20.0,
227
+ step=0.1,
228
+ value=20.0,
229
+ )
230
+
231
+ image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
232
+
233
+ with gr.Row():
234
+ delete_all_button = gr.Button("🗑️ Delete All Images")
235
+ like_button = gr.Button("👍 Like")
236
+ dislike_button = gr.Button("👎 Dislike")
237
+ heart_button = gr.Button("❤️ Heart")
238
+ delete_image_button = gr.Button("🗑️ Delete Selected Image")
239
+
240
+ selected_image = gr.State(None)
241
+
242
+ gr.Examples(
243
+ examples=examples,
244
+ inputs=prompt,
245
+ outputs=[result, seed],
246
+ fn=generate,
247
+ cache_examples=False,
248
+ )
249
+
250
+ use_negative_prompt.change(
251
+ fn=lambda x: gr.update(visible=x),
252
+ inputs=use_negative_prompt,
253
+ outputs=negative_prompt,
254
+ api_name=False,
255
+ )
256
+
257
+ delete_all_button.click(
258
+ fn=delete_all_images,
259
+ inputs=[],
260
+ outputs=[image_gallery],
261
+ )
262
+
263
+ image_gallery.select(
264
+ fn=lambda evt: evt,
265
+ inputs=[gr.State("value")],
266
+ outputs=[selected_image],
267
+ )
268
+
269
+ like_button.click(
270
+ fn=lambda x: vote(x, 'likes'),
271
+ inputs=[selected_image],
272
+ outputs=[image_gallery],
273
+ )
274
+
275
+ dislike_button.click(
276
+ fn=lambda x: vote(x, 'dislikes'),
277
+ inputs=[selected_image],
278
+ outputs=[image_gallery],
279
+ )
280
+
281
+ heart_button.click(
282
+ fn=lambda x: vote(x, 'hearts'),
283
+ inputs=[selected_image],
284
+ outputs=[image_gallery],
285
+ )
286
+
287
+ delete_image_button.click(
288
+ fn=delete_image,
289
+ inputs=[selected_image],
290
+ outputs=[image_gallery],
291
+ )
292
+
293
+ def update_gallery():
294
+ return gr.update(value=get_image_gallery())
295
+
296
+ gr.on(
297
+ triggers=[
298
+ prompt.submit,
299
+ negative_prompt.submit,
300
+ run_button.click,
301
+ ],
302
+ fn=generate,
303
+ inputs=[
304
+ prompt,
305
+ negative_prompt,
306
+ use_negative_prompt,
307
+ seed,
308
+ width,
309
+ height,
310
+ guidance_scale,
311
+ randomize_seed,
312
+ ],
313
+ outputs=[result, seed, gr.HTML(visible=False), image_gallery],
314
+ api_name="run",
315
+ )
316
+
317
+ demo.load(fn=update_gallery, outputs=image_gallery)
318
+
319
+ if __name__ == "__main__":
320
+ demo.queue(max_size=20).launch(show_api=False, debug=False)