Sergidev committed
Commit 02a5208 · verified · 1 Parent(s): c470655
Files changed (1):
  1. app.py +62 -19
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import gc
+import random
 import gradio as gr
 import numpy as np
 import torch
@@ -38,7 +39,6 @@ torch.backends.cudnn.benchmark = False
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-
 def load_pipeline(model_name):
     vae = AutoencoderKL.from_pretrained(
         "madebyollin/sdxl-vae-fp16-fix",
@@ -64,8 +64,32 @@ def load_pipeline(model_name):
     pipe.to(device)
     return pipe
 
+def parse_json_parameters(json_str):
+    try:
+        params = json.loads(json_str)
+        return params
+    except json.JSONDecodeError:
+        return None
+
+def apply_json_parameters(json_str):
+    params = parse_json_parameters(json_str)
+    if params:
+        return (
+            params.get("prompt", ""),
+            params.get("negative_prompt", ""),
+            params.get("seed", 0),
+            params.get("width", 1024),
+            params.get("height", 1024),
+            params.get("guidance_scale", 7.0),
+            params.get("num_inference_steps", 30),
+            params.get("sampler", "DPM++ 2M SDE Karras"),
+            params.get("aspect_ratio", "1024 x 1024"),
+            params.get("use_upscaler", False),
+            params.get("upscaler_strength", 0.55),
+            params.get("upscale_by", 1.5),
+        )
+    return [gr.update()] * 12
 
-@spaces.GPU
 def generate(
     prompt: str,
     negative_prompt: str = "",
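apply_json_parameters maps a pasted JSON string onto the twelve generator inputs, falling back to [gr.update()] * 12 (all no-ops) when parsing fails. Two things worth noting: this hunk only adds import random, so unless json is already imported somewhere in the unshown lines 6-37 of app.py, json.loads will raise NameError; and an empty object {} parses to a falsy dict, so it is treated like a parse failure and changes nothing. A payload with every recognized key, matching the defaults in the params.get calls (the values here are illustrative only):

    {
        "prompt": "1girl, solo, watercolor",
        "negative_prompt": "lowres, bad anatomy",
        "seed": 42,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 7.0,
        "num_inference_steps": 30,
        "sampler": "DPM++ 2M SDE Karras",
        "aspect_ratio": "1024 x 1024",
        "use_upscaler": false,
        "upscaler_strength": 0.55,
        "upscale_by": 1.5
    }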
@@ -159,7 +183,12 @@ def generate(
                 filepath = utils.save_image(image, metadata, OUTPUT_DIR)
                 logger.info(f"Image saved as {filepath} with metadata")
 
-        return images, metadata
+        # Update history after generation
+        history = gr.get_state("history") or []
+        history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
+        gr.set_state("history", history[:10])  # Keep only the last 10 entries
+
+        return images, metadata, gr.update(choices=[h["prompt"] for h in history])
     except Exception as e:
         logger.exception(f"An error occurred: {e}")
         raise
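gr.get_state and gr.set_state are not part of the Gradio Blocks API (per-session values live in gr.State components), so this block would raise AttributeError the first time generate completes. A minimal sketch of the same bookkeeping with a module-level list instead; the generation_history name is hypothetical, and it shares one history across all sessions, unlike a per-session gr.State:

    generation_history = []  # defined once at module level

    # inside generate(), in place of the gr.get_state / gr.set_state lines:
    generation_history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
    del generation_history[10:]  # keep only the 10 most recent entries
    return images, metadata, gr.update(choices=[h["prompt"] for h in generation_history])

Whichever store is used, generate now returns three values, so every handler that calls it needs three outputs; the later hunks update prompt.submit and negative_prompt.submit to outputs=[result, gr_metadata, history_dropdown] accordingly.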
@@ -169,6 +198,21 @@ def generate(
         pipe.scheduler = backup_scheduler
         utils.free_memory()
 
+def get_random_prompt():
+    anime_characters = [
+        "Naruto Uzumaki", "Monkey D. Luffy", "Goku", "Eren Yeager", "Light Yagami",
+        "Lelouch Lamperouge", "Edward Elric", "Levi Ackerman", "Spike Spiegel",
+        "Sakura Haruno", "Mikasa Ackerman", "Asuka Langley Soryu", "Rem", "Megumin",
+        "Violet Evergarden"
+    ]
+    styles = ["pixel art", "stylized anime", "digital art", "watercolor", "sketch"]
+    scores = ["score_9", "score_8_up", "score_7_up"]
+
+    character = random.choice(anime_characters)
+    style = random.choice(styles)
+    score = ", ".join(random.sample(scores, k=3))
+
+    return f"{score}, {character}, {style}, show accurate"
 
 if torch.cuda.is_available():
     pipe = load_pipeline(MODEL)
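get_random_prompt assembles score_9-style quality tags plus a random character and style; since random.sample(scores, k=3) draws all three tags, it amounts to shuffling them. A typical return value:

    score_8_up, score_9, score_7_up, Megumin, watercolor, show accurate

Nothing in this diff calls the function; see the wiring sketch after the UI hunk below.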
@@ -285,8 +329,19 @@ with gr.Blocks(css="style.css") as demo:
                     step=1,
                     value=28,
                 )
+            with gr.Accordion(label="JSON Parameters", open=False):
+                json_input = gr.TextArea(label="Input JSON parameters")
+                apply_json_button = gr.Button("Apply JSON Parameters")
+
+            with gr.Row():
+                clear_button = gr.Button("Clear All")
+                random_prompt_button = gr.Button("Random Prompt")
+
+            history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True)
+
             with gr.Accordion(label="Generation Parameters", open=False):
                 gr_metadata = gr.JSON(label="Metadata", show_label=False)
+
     gr.Examples(
         examples=config.examples,
         inputs=prompt,
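The new controls (json_input, apply_json_button, clear_button, random_prompt_button, history_dropdown) are declared here, but no event handlers for them appear anywhere in this diff. A minimal sketch of the wiring the helper functions imply; the component names not shown in this diff (custom_width, custom_height, guidance_scale, num_inference_steps, sampler, aspect_ratio_selector, upscaler_strength, upscale_by) are assumptions chosen to match the apply_json_parameters return order:

    # targets in the order apply_json_parameters returns values
    json_targets = [prompt, negative_prompt, seed, custom_width, custom_height,
                    guidance_scale, num_inference_steps, sampler, aspect_ratio_selector,
                    use_upscaler, upscaler_strength, upscale_by]
    apply_json_button.click(fn=apply_json_parameters, inputs=json_input, outputs=json_targets)
    random_prompt_button.click(fn=get_random_prompt, inputs=None, outputs=prompt)
    # '{"prompt": ""}' rather than "{}": an empty dict is falsy and would change nothing
    clear_button.click(fn=lambda: apply_json_parameters('{"prompt": ""}'), inputs=None, outputs=json_targets)
    history_dropdown.change(fn=lambda p: p, inputs=history_dropdown, outputs=prompt)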
@@ -294,6 +349,7 @@ with gr.Blocks(css="style.css") as demo:
         fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
         cache_examples=CACHE_EXAMPLES,
     )
+
     use_upscaler.change(
         fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
         inputs=use_upscaler,
@@ -333,7 +389,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=result,
+        outputs=[result, gr_metadata, history_dropdown],
         api_name="run",
     )
     negative_prompt.submit(
@@ -345,19 +401,6 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=result,
-        api_name=False,
-    )
-    run_button.click(
-        fn=utils.randomize_seed_fn,
-        inputs=[seed, randomize_seed],
-        outputs=seed,
-        queue=False,
-        api_name=False,
-    ).then(
-        fn=generate,
-        inputs=inputs,
-        outputs=[result, gr_metadata],
+        outputs=[result, gr_metadata, history_dropdown],
         api_name=False,
-    )
-demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
+    )
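Note that this last hunk deletes the run_button.click chain and the demo.queue(...).launch(...) call without adding replacements anywhere in the diff, so as committed the Run button is unwired and the app is never launched. Restoring the deleted lines, updated to the three-output signature the commit itself uses for prompt.submit and negative_prompt.submit:

    run_button.click(
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=[result, gr_metadata, history_dropdown],
        api_name=False,
    )
    demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)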
 