Update app.py

#2
by linoyts HF Staff - opened
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -78,7 +78,7 @@ def load_lora_weights(repo_id, weights_filename):
78
  def update_selection(selected_state: gr.SelectData, flux_loras):
79
  """Update UI when a LoRA is selected"""
80
  if selected_state.index >= len(flux_loras):
81
- return "### No LoRA selected", gr.update(), selected_state
82
 
83
  lora_repo = flux_loras[selected_state.index]["repo"]
84
  trigger_word = flux_loras[selected_state.index]["trigger_word"]
@@ -86,7 +86,7 @@ def update_selection(selected_state: gr.SelectData, flux_loras):
86
  updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
87
  new_placeholder = f"Enter your editing prompt{f' (use {trigger_word} for best results)' if trigger_word else ''}"
88
 
89
- return updated_text, gr.update(placeholder=new_placeholder), selected_state
90
 
91
  def get_huggingface_lora(link):
92
  """Download LoRA from HuggingFace link"""
@@ -117,7 +117,7 @@ def get_huggingface_lora(link):
117
  def load_custom_lora(link):
118
  """Load custom LoRA from user input"""
119
  if not link:
120
- return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it"
121
 
122
  try:
123
  repo_name, weights_file, trigger_word = get_huggingface_lora(link)
@@ -138,22 +138,26 @@ def load_custom_lora(link):
138
  "trigger_word": trigger_word
139
  }
140
 
141
- return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}"
142
 
143
  except Exception as e:
144
- return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it"
145
 
146
  def remove_custom_lora():
147
  """Remove custom LoRA"""
148
- return "", gr.update(visible=False), gr.update(visible=False), None
149
 
150
  def classify_gallery(flux_loras):
151
  """Sort gallery by likes"""
152
  sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
153
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
154
 
 
 
 
 
155
  @spaces.GPU
156
- def infer_with_lora(input_image, prompt, selected_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
157
  """Generate image with selected LoRA"""
158
  global current_lora, pipe
159
 
@@ -164,10 +168,8 @@ def infer_with_lora(input_image, prompt, selected_state, custom_lora, seed=42, r
164
  lora_to_use = None
165
  if custom_lora:
166
  lora_to_use = custom_lora
167
- elif selected_state and flux_loras:
168
- selected_index = selected_state.index if hasattr(selected_state, 'index') else None
169
- if selected_index is not None and selected_index < len(flux_loras):
170
- lora_to_use = flux_loras[selected_index]
171
 
172
  # Load LoRA if needed
173
  if lora_to_use and lora_to_use != current_lora:
@@ -265,8 +267,8 @@ with gr.Blocks(css=css) as demo:
265
  <br><small style="font-size: 13px; opacity: 0.75;"></small></h1>""",
266
  )
267
 
268
- selected_state = gr.State()
269
- custom_loaded_lora = gr.State()
270
 
271
  with gr.Row(elem_id="main_app"):
272
  with gr.Column(scale=4, elem_id="box_column"):
@@ -340,12 +342,12 @@ with gr.Blocks(css=css) as demo:
340
  custom_model.input(
341
  fn=load_custom_lora,
342
  inputs=[custom_model],
343
- outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
344
  )
345
 
346
  custom_model_button.click(
347
  fn=remove_custom_lora,
348
- outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
349
  )
350
 
351
  gallery.select(
@@ -357,7 +359,7 @@ with gr.Blocks(css=css) as demo:
357
 
358
  gr.on(
359
  triggers=[run_button.click, prompt.submit],
360
- fn=infer_with_lora,
361
  inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
362
  outputs=[result, seed, reuse_button]
363
  )
 
78
  def update_selection(selected_state: gr.SelectData, flux_loras):
79
  """Update UI when a LoRA is selected"""
80
  if selected_state.index >= len(flux_loras):
81
+ return "### No LoRA selected", gr.update(), None
82
 
83
  lora_repo = flux_loras[selected_state.index]["repo"]
84
  trigger_word = flux_loras[selected_state.index]["trigger_word"]
 
86
  updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
87
  new_placeholder = f"Enter your editing prompt{f' (use {trigger_word} for best results)' if trigger_word else ''}"
88
 
89
+ return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
90
 
91
  def get_huggingface_lora(link):
92
  """Download LoRA from HuggingFace link"""
 
117
  def load_custom_lora(link):
118
  """Load custom LoRA from user input"""
119
  if not link:
120
+ return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None
121
 
122
  try:
123
  repo_name, weights_file, trigger_word = get_huggingface_lora(link)
 
138
  "trigger_word": trigger_word
139
  }
140
 
141
+ return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
142
 
143
  except Exception as e:
144
+ return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
145
 
146
  def remove_custom_lora():
147
  """Remove custom LoRA"""
148
+ return "", gr.update(visible=False), gr.update(visible=False), None, None
149
 
150
  def classify_gallery(flux_loras):
151
  """Sort gallery by likes"""
152
  sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
153
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
154
 
155
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
156
+ """Wrapper function to handle state serialization"""
157
+ return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
158
+
159
  @spaces.GPU
160
+ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
161
  """Generate image with selected LoRA"""
162
  global current_lora, pipe
163
 
 
168
  lora_to_use = None
169
  if custom_lora:
170
  lora_to_use = custom_lora
171
+ elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
172
+ lora_to_use = flux_loras[selected_index]
 
 
173
 
174
  # Load LoRA if needed
175
  if lora_to_use and lora_to_use != current_lora:
 
267
  <br><small style="font-size: 13px; opacity: 0.75;"></small></h1>""",
268
  )
269
 
270
+ selected_state = gr.State(value=None)
271
+ custom_loaded_lora = gr.State(value=None)
272
 
273
  with gr.Row(elem_id="main_app"):
274
  with gr.Column(scale=4, elem_id="box_column"):
 
342
  custom_model.input(
343
  fn=load_custom_lora,
344
  inputs=[custom_model],
345
+ outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
346
  )
347
 
348
  custom_model_button.click(
349
  fn=remove_custom_lora,
350
+ outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
351
  )
352
 
353
  gallery.select(
 
359
 
360
  gr.on(
361
  triggers=[run_button.click, prompt.submit],
362
+ fn=infer_with_lora_wrapper,
363
  inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
364
  outputs=[result, seed, reuse_button]
365
  )