awacke1 commited on
Commit
e720777
·
verified ·
1 Parent(s): 6fb1bdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -86
app.py CHANGED
@@ -1,48 +1,51 @@
1
  import os
2
  import random
 
 
3
  import gradio as gr
4
  import numpy as np
5
  from PIL import Image
 
6
  import torch
7
  import glob
8
  from datetime import datetime
9
  import pandas as pd
10
  import json
11
  import re
12
- import logging
13
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
14
 
15
- # Set up logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
 
19
  DESCRIPTION = """# 🎨 ArtForge: Community AI Gallery
20
- Create, curate, and compete with AI-generated art. Join our creative multiplayer experience! 🖼️🏆✨"""
21
 
22
- METADATA_FILE = "image_metadata.json"
23
- MAX_SEED = np.iinfo(np.int32).max
24
 
25
  # Global variables
26
  image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
 
27
 
28
- def load_metadata():
29
- global image_metadata
30
- if os.path.exists(METADATA_FILE):
31
- with open(METADATA_FILE, 'r') as f:
32
- image_metadata = pd.DataFrame(json.load(f))
33
- else:
34
- image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
35
 
36
- def save_metadata():
37
- with open(METADATA_FILE, 'w') as f:
38
- json.dump(image_metadata.to_dict('records'), f)
39
 
40
- load_metadata()
 
 
 
 
 
 
41
 
42
  def save_image(img, prompt):
43
- global image_metadata
44
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
45
- safe_prompt = re.sub(r'[^\w\s-]', '', prompt.lower())[:50]
46
  safe_prompt = re.sub(r'[-\s]+', '-', safe_prompt).strip('-')
47
  filename = f"{timestamp}_{safe_prompt}.png"
48
  img.save(filename)
@@ -52,51 +55,75 @@ def save_image(img, prompt):
52
  'Likes': [0],
53
  'Dislikes': [0],
54
  'Hearts': [0],
55
- 'Created': [str(datetime.now())]
56
  })
57
- image_metadata = pd.concat([image_metadata, new_row], ignore_index=True, sort=False)
58
- save_metadata()
59
- logger.info(f"Saved new image: {filename}")
60
  return filename
61
 
 
 
 
 
 
62
  def get_image_gallery():
63
- return [(file, get_image_caption(file)) for file in image_metadata['Filename'] if os.path.exists(file)]
 
 
64
 
65
  def get_image_caption(filename):
66
- if filename in image_metadata['Filename'].values:
67
- row = image_metadata[image_metadata['Filename'] == filename].iloc[0]
68
- return f"{filename}\nPrompt: {row['Prompt']}\n👍 {row['Likes']} 👎 {row['Dislikes']} ❤️ {row['Hearts']}"
 
 
 
 
69
  return filename
70
 
71
  def delete_all_images():
72
- global image_metadata
73
  for file in image_metadata['Filename']:
74
  if os.path.exists(file):
75
  os.remove(file)
76
  image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
77
- save_metadata()
78
- logger.info("All images deleted")
79
  return get_image_gallery(), image_metadata.values.tolist()
80
 
81
  def delete_image(filename):
82
- global image_metadata
83
  if filename and os.path.exists(filename):
84
  os.remove(filename)
85
  image_metadata = image_metadata[image_metadata['Filename'] != filename]
86
- save_metadata()
87
- logger.info(f"Deleted image: {filename}")
 
88
  return get_image_gallery(), image_metadata.values.tolist()
89
 
90
  def vote(filename, vote_type):
91
- global image_metadata
92
- if filename in image_metadata['Filename'].values:
93
- image_metadata.loc[image_metadata['Filename'] == filename, vote_type] += 1
94
- save_metadata()
95
- logger.info(f"Updated {vote_type} count for {filename}")
96
- else:
97
- logger.warning(f"File {filename} not found in metadata")
98
  return get_image_gallery(), image_metadata.values.tolist()
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if torch.cuda.is_available():
101
  pipe = StableDiffusionXLPipeline.from_pretrained(
102
  "fluently/Fluently-XL-v4",
@@ -104,22 +131,29 @@ if torch.cuda.is_available():
104
  use_safetensors=True,
105
  )
106
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
107
  pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
108
  pipe.set_adapters("dalle")
 
109
  pipe.to("cuda")
110
- else:
111
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
112
 
 
113
  def generate(
114
- prompt, negative_prompt="", use_negative_prompt=False,
115
- seed=0, width=1024, height=1024, guidance_scale=3, randomize_seed=False,
116
- progress=gr.Progress(track_tqdm=True)
 
 
 
 
 
 
117
  ):
118
- if randomize_seed:
119
- seed = random.randint(0, MAX_SEED)
120
  if not use_negative_prompt:
121
  negative_prompt = ""
122
-
123
  images = pipe(
124
  prompt=prompt,
125
  negative_prompt=negative_prompt,
@@ -132,12 +166,29 @@ def generate(
132
  output_type="pil",
133
  ).images
134
  image_paths = [save_image(img, prompt) for img in images]
135
- return image_paths, seed, get_image_gallery(), image_metadata.values.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  css = '''
138
  .gradio-container{max-width: 1024px !important}
139
  h1{text-align:center}
140
- footer {visibility: hidden}
 
 
141
  '''
142
 
143
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
@@ -145,53 +196,154 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
145
 
146
  with gr.Tab("Generate Images"):
147
  with gr.Group():
148
- prompt = gr.Text(label="Prompt", placeholder="Enter your prompt")
149
- run_button = gr.Button("Generate")
150
- result = gr.Gallery(label="Result", columns=1, preview=True)
 
 
 
 
 
 
 
151
  with gr.Accordion("Advanced options", open=False):
152
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
153
- negative_prompt = gr.Text(label="Negative prompt", visible=False)
154
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
156
- width = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
157
- height = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
158
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=7.5)
159
-
160
- with gr.Tab("Gallery"):
161
- image_gallery = gr.Gallery(label="Generated Images", columns=4, height="auto")
162
- delete_image_button = gr.Button("🗑️ Delete Selected Image")
163
- selected_image = gr.State(None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- with gr.Tab("Metadata and Management"):
166
- metadata_df = gr.Dataframe(label="Image Metadata", headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"])
 
167
  with gr.Row():
168
  like_button = gr.Button("👍 Like")
169
  dislike_button = gr.Button("👎 Dislike")
170
  heart_button = gr.Button("❤️ Heart")
 
 
 
 
 
 
 
 
 
 
171
  delete_all_button = gr.Button("🗑️ Delete All Images")
172
 
173
- use_negative_prompt.change(lambda x: gr.update(visible=x), inputs=use_negative_prompt, outputs=negative_prompt)
 
 
 
 
 
174
 
175
- run_button.click(
176
- fn=generate,
177
- inputs=[prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, randomize_seed],
178
- outputs=[result, seed, image_gallery, metadata_df]
179
  )
180
 
181
- delete_all_button.click(fn=delete_all_images, outputs=[image_gallery, metadata_df])
182
-
183
- image_gallery.select(fn=lambda evt, x=None: evt, inputs=None, outputs=selected_image)
184
-
185
- delete_image_button.click(fn=delete_image, inputs=[selected_image], outputs=[image_gallery, metadata_df])
186
 
187
- for button, vote_type in [(like_button, 'Likes'), (dislike_button, 'Dislikes'), (heart_button, 'Hearts')]:
188
- button.click(
189
- fn=lambda x, vt=vote_type: vote(x, vt) if x else (get_image_gallery(), image_metadata.values.tolist()),
190
- inputs=[selected_image],
191
- outputs=[image_gallery, metadata_df]
192
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
- demo.load(fn=lambda: (get_image_gallery(), image_metadata.values.tolist()), outputs=[image_gallery, metadata_df])
195
 
196
  if __name__ == "__main__":
197
- demo.queue(max_size=20).launch(debug=False)
 
1
  import os
2
  import random
3
+ import uuid
4
+ import base64
5
  import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
+ import spaces
9
  import torch
10
  import glob
11
  from datetime import datetime
12
  import pandas as pd
13
  import json
14
  import re
 
 
15
 
16
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 
17
 
18
  DESCRIPTION = """# 🎨 ArtForge: Community AI Gallery
 
19
 
20
+ Create, curate, and compete with AI-generated art. Join our creative multiplayer experience! 🖼️🏆✨
21
+ """
22
 
23
  # Global variables
24
  image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
25
+ LIKES_CACHE_FILE = "likes_cache.json"
26
 
27
+ def load_likes_cache():
28
+ if os.path.exists(LIKES_CACHE_FILE):
29
+ with open(LIKES_CACHE_FILE, 'r') as f:
30
+ return json.load(f)
31
+ return {}
 
 
32
 
33
+ def save_likes_cache(cache):
34
+ with open(LIKES_CACHE_FILE, 'w') as f:
35
+ json.dump(cache, f)
36
 
37
+ likes_cache = load_likes_cache()
38
+
39
+ def create_download_link(filename):
40
+ with open(filename, "rb") as file:
41
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
42
+ download_link = f'<a href="data:image/png;base64,{encoded_string}" download="{filename}">Download Image</a>'
43
+ return download_link
44
 
45
  def save_image(img, prompt):
46
+ global image_metadata, likes_cache
47
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
48
+ safe_prompt = re.sub(r'[^\w\s-]', '', prompt.lower())[:50] # Limit to 50 characters
49
  safe_prompt = re.sub(r'[-\s]+', '-', safe_prompt).strip('-')
50
  filename = f"{timestamp}_{safe_prompt}.png"
51
  img.save(filename)
 
55
  'Likes': [0],
56
  'Dislikes': [0],
57
  'Hearts': [0],
58
+ 'Created': [datetime.now()]
59
  })
60
+ image_metadata = pd.concat([image_metadata, new_row], ignore_index=True)
61
+ likes_cache[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
62
+ save_likes_cache(likes_cache)
63
  return filename
64
 
65
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
66
+ if randomize_seed:
67
+ seed = random.randint(0, MAX_SEED)
68
+ return seed
69
+
70
  def get_image_gallery():
71
+ global image_metadata
72
+ image_files = image_metadata['Filename'].tolist()
73
+ return [(file, get_image_caption(file)) for file in image_files if os.path.exists(file)]
74
 
75
  def get_image_caption(filename):
76
+ global likes_cache, image_metadata
77
+ if filename in likes_cache:
78
+ likes = likes_cache[filename]['likes']
79
+ dislikes = likes_cache[filename]['dislikes']
80
+ hearts = likes_cache[filename]['hearts']
81
+ prompt = image_metadata[image_metadata['Filename'] == filename]['Prompt'].values[0]
82
+ return f"{filename}\nPrompt: {prompt}\n👍 {likes} 👎 {dislikes} ❤️ {hearts}"
83
  return filename
84
 
85
  def delete_all_images():
86
+ global image_metadata, likes_cache
87
  for file in image_metadata['Filename']:
88
  if os.path.exists(file):
89
  os.remove(file)
90
  image_metadata = pd.DataFrame(columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created'])
91
+ likes_cache = {}
92
+ save_likes_cache(likes_cache)
93
  return get_image_gallery(), image_metadata.values.tolist()
94
 
95
  def delete_image(filename):
96
+ global image_metadata, likes_cache
97
  if filename and os.path.exists(filename):
98
  os.remove(filename)
99
  image_metadata = image_metadata[image_metadata['Filename'] != filename]
100
+ if filename in likes_cache:
101
+ del likes_cache[filename]
102
+ save_likes_cache(likes_cache)
103
  return get_image_gallery(), image_metadata.values.tolist()
104
 
105
  def vote(filename, vote_type):
106
+ global likes_cache
107
+ if filename in likes_cache:
108
+ likes_cache[filename][vote_type.lower()] += 1
109
+ save_likes_cache(likes_cache)
 
 
 
110
  return get_image_gallery(), image_metadata.values.tolist()
111
 
112
+ def get_random_style():
113
+ styles = [
114
+ "Impressionist", "Cubist", "Surrealist", "Abstract Expressionist",
115
+ "Pop Art", "Minimalist", "Baroque", "Art Nouveau", "Pointillist", "Fauvism"
116
+ ]
117
+ return random.choice(styles)
118
+
119
+ MAX_SEED = np.iinfo(np.int32).max
120
+
121
+ if not torch.cuda.is_available():
122
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
123
+
124
+ USE_TORCH_COMPILE = 0
125
+ ENABLE_CPU_OFFLOAD = 0
126
+
127
  if torch.cuda.is_available():
128
  pipe = StableDiffusionXLPipeline.from_pretrained(
129
  "fluently/Fluently-XL-v4",
 
131
  use_safetensors=True,
132
  )
133
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
134
+
135
  pipe.load_lora_weights("ehristoforu/dalle-3-xl-v2", weight_name="dalle-3-xl-lora-v2.safetensors", adapter_name="dalle")
136
  pipe.set_adapters("dalle")
137
+
138
  pipe.to("cuda")
 
 
139
 
140
+ @spaces.GPU(enable_queue=True)
141
  def generate(
142
+ prompt: str,
143
+ negative_prompt: str = "",
144
+ use_negative_prompt: bool = False,
145
+ seed: int = 0,
146
+ width: int = 1024,
147
+ height: int = 1024,
148
+ guidance_scale: float = 3,
149
+ randomize_seed: bool = False,
150
+ progress=gr.Progress(track_tqdm=True),
151
  ):
152
+ seed = int(randomize_seed_fn(seed, randomize_seed))
153
+
154
  if not use_negative_prompt:
155
  negative_prompt = ""
156
+
157
  images = pipe(
158
  prompt=prompt,
159
  negative_prompt=negative_prompt,
 
166
  output_type="pil",
167
  ).images
168
  image_paths = [save_image(img, prompt) for img in images]
169
+ download_links = [create_download_link(path) for path in image_paths]
170
+
171
+ return image_paths, seed, download_links, get_image_gallery(), image_metadata.values.tolist()
172
+
173
+ examples = [
174
+ f"{get_random_style()} painting of a majestic lighthouse on a rocky coast. Use bold brushstrokes and a vibrant color palette to capture the interplay of light and shadow as the lighthouse beam cuts through a stormy night sky.",
175
+ f"{get_random_style()} still life featuring a pair of vintage eyeglasses. Focus on the intricate details of the frames and lenses, using a warm color scheme to evoke a sense of nostalgia and wisdom.",
176
+ f"{get_random_style()} depiction of a rustic wooden stool in a sunlit artist's studio. Emphasize the texture of the wood and the interplay of light and shadow, using a mix of earthy tones and highlights.",
177
+ f"{get_random_style()} scene viewed through an ornate window frame. Contrast the intricate details of the window with a dreamy, soft-focus landscape beyond, using a palette that transitions from cool interior tones to warm exterior hues.",
178
+ f"{get_random_style()} close-up study of interlaced fingers. Use a monochromatic color scheme to emphasize the form and texture of the hands, with dramatic lighting to create depth and emotion.",
179
+ f"{get_random_style()} composition featuring a set of dice in motion. Capture the energy and randomness of the throw, using a dynamic color palette and blurred lines to convey movement.",
180
+ f"{get_random_style()} interpretation of heaven. Create an ethereal atmosphere with soft, billowing clouds and radiant light, using a palette of celestial blues, golds, and whites.",
181
+ f"{get_random_style()} portrayal of an ancient, mystical gate. Combine architectural details with elements of fantasy, using a rich, jewel-toned palette to create an air of mystery and magic.",
182
+ f"{get_random_style()} portrait of a curious cat. Focus on capturing the feline's expressive eyes and sleek form, using a mix of bold and subtle colors to bring out the cat's personality.",
183
+ f"{get_random_style()} abstract representation of toes in sand. Use textured brushstrokes to convey the feeling of warm sand, with a palette inspired by a sun-drenched beach."
184
+ ]
185
 
186
  css = '''
187
  .gradio-container{max-width: 1024px !important}
188
  h1{text-align:center}
189
+ footer {
190
+ visibility: hidden
191
+ }
192
  '''
193
 
194
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
 
196
 
197
  with gr.Tab("Generate Images"):
198
  with gr.Group():
199
+ with gr.Row():
200
+ prompt = gr.Text(
201
+ label="Prompt",
202
+ show_label=False,
203
+ max_lines=1,
204
+ placeholder="Enter your prompt",
205
+ container=False,
206
+ )
207
+ run_button = gr.Button("Run", scale=0)
208
+ result = gr.Gallery(label="Result", columns=1, preview=True, show_label=False)
209
  with gr.Accordion("Advanced options", open=False):
210
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
211
+ negative_prompt = gr.Text(
212
+ label="Negative prompt",
213
+ lines=4,
214
+ max_lines=6,
215
+ value="""(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, (NSFW:1.25)""",
216
+ placeholder="Enter a negative prompt",
217
+ visible=True,
218
+ )
219
+ seed = gr.Slider(
220
+ label="Seed",
221
+ minimum=0,
222
+ maximum=MAX_SEED,
223
+ step=1,
224
+ value=0,
225
+ visible=True
226
+ )
227
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
228
+ with gr.Row(visible=True):
229
+ width = gr.Slider(
230
+ label="Width",
231
+ minimum=512,
232
+ maximum=2048,
233
+ step=8,
234
+ value=1920,
235
+ )
236
+ height = gr.Slider(
237
+ label="Height",
238
+ minimum=512,
239
+ maximum=2048,
240
+ step=8,
241
+ value=1080,
242
+ )
243
+ with gr.Row():
244
+ guidance_scale = gr.Slider(
245
+ label="Guidance Scale",
246
+ minimum=0.1,
247
+ maximum=20.0,
248
+ step=0.1,
249
+ value=20.0,
250
+ )
251
+
252
+ gr.Examples(
253
+ examples=examples,
254
+ inputs=prompt,
255
+ outputs=[result, seed],
256
+ fn=generate,
257
+ cache_examples=False,
258
+ )
259
 
260
+ with gr.Tab("Gallery and Voting"):
261
+ image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
262
+
263
  with gr.Row():
264
  like_button = gr.Button("👍 Like")
265
  dislike_button = gr.Button("👎 Dislike")
266
  heart_button = gr.Button("❤️ Heart")
267
+ delete_image_button = gr.Button("🗑️ Delete Selected Image")
268
+
269
+ selected_image = gr.State(None)
270
+
271
+ with gr.Tab("Metadata and Management"):
272
+ metadata_df = gr.Dataframe(
273
+ label="Image Metadata",
274
+ headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
275
+ interactive=False
276
+ )
277
  delete_all_button = gr.Button("🗑️ Delete All Images")
278
 
279
+ use_negative_prompt.change(
280
+ fn=lambda x: gr.update(visible=x),
281
+ inputs=use_negative_prompt,
282
+ outputs=negative_prompt,
283
+ api_name=False,
284
+ )
285
 
286
+ delete_all_button.click(
287
+ fn=delete_all_images,
288
+ inputs=[],
289
+ outputs=[image_gallery, metadata_df],
290
  )
291
 
292
+ image_gallery.select(
293
+ fn=lambda evt: evt,
294
+ inputs=[],
295
+ outputs=[selected_image],
296
+ )
297
 
298
+ like_button.click(
299
+ fn=lambda x: vote(x, 'likes'),
300
+ inputs=[selected_image],
301
+ outputs=[image_gallery, metadata_df],
302
+ )
303
+
304
+ dislike_button.click(
305
+ fn=lambda x: vote(x, 'dislikes'),
306
+ inputs=[selected_image],
307
+ outputs=[image_gallery, metadata_df],
308
+ )
309
+
310
+ heart_button.click(
311
+ fn=lambda x: vote(x, 'hearts'),
312
+ inputs=[selected_image],
313
+ outputs=[image_gallery, metadata_df],
314
+ )
315
+
316
+ delete_image_button.click(
317
+ fn=delete_image,
318
+ inputs=[selected_image],
319
+ outputs=[image_gallery, metadata_df],
320
+ )
321
+
322
+ def update_gallery_and_metadata():
323
+ return gr.update(value=get_image_gallery()), gr.update(value=image_metadata.values.tolist())
324
+
325
+ gr.on(
326
+ triggers=[
327
+ prompt.submit,
328
+ negative_prompt.submit,
329
+ run_button.click,
330
+ ],
331
+ fn=generate,
332
+ inputs=[
333
+ prompt,
334
+ negative_prompt,
335
+ use_negative_prompt,
336
+ seed,
337
+ width,
338
+ height,
339
+ guidance_scale,
340
+ randomize_seed,
341
+ ],
342
+ outputs=[result, seed, gr.HTML(visible=False), image_gallery, metadata_df],
343
+ api_name="run",
344
+ )
345
 
346
+ demo.load(fn=update_gallery_and_metadata, outputs=[image_gallery, metadata_df])
347
 
348
  if __name__ == "__main__":
349
+ demo.queue(max_size=20).launch(share=True, debug=False)