eyal.benaroche commited on
Commit
7d4c917
β€’
1 Parent(s): 29a59f2

simplify logic

Browse files
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -84,7 +84,24 @@ def infer(
84
  num_inference_steps,
85
  negative_prompt,
86
  guidance_scale,
 
 
 
87
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  if randomize_seed:
90
  seed = random.randint(0, MAX_SEED)
@@ -102,7 +119,7 @@ def infer(
102
  generator=generator,
103
  ).images[0]
104
 
105
- return image
106
 
107
 
108
  css = """
@@ -258,10 +275,6 @@ with gr.Blocks(css=css) as demo:
258
  negative_prompt.submit,
259
  guidance_scale.change,
260
  ],
261
- fn=check_and_load_lora_user,
262
- inputs=[user_lora_selector, user_lora_weight, gr_lora_loaded],
263
- outputs=[gr_lora_loaded],
264
- ).success(
265
  fn=infer,
266
  inputs=[
267
  pre_prompt,
@@ -271,8 +284,11 @@ with gr.Blocks(css=css) as demo:
271
  num_inference_steps,
272
  negative_prompt,
273
  guidance_scale,
 
 
 
274
  ],
275
- outputs=[result],
276
  show_progress="minimal",
277
  )
278
 
 
84
  num_inference_steps,
85
  negative_prompt,
86
  guidance_scale,
87
+ user_lora_selector,
88
+ user_lora_weight,
89
+ gr_lora_loaded,
90
  ):
91
+ flash_sdxl_id = "jasperai/flash-sdxl"
92
+
93
+ if user_lora_selector == "" or user_lora_selector == "":
94
+ raise gr.Error("Please select a LoRA before running the inference.")
95
+
96
+ if gr_lora_loaded != user_lora_selector:
97
+ gr.Info("Loading LoRA")
98
+ pipe.unload_lora_weights()
99
+ pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
100
+ pipe.load_lora_weights(user_lora_selector, adapter_name="user")
101
+ pipe.set_adapters(["lora", "user"], adapter_weights=[1.0, user_lora_weight])
102
+ gr.Info("LoRA Loaded")
103
+
104
+ gr_lora_loaded = user_lora_selector
105
 
106
  if randomize_seed:
107
  seed = random.randint(0, MAX_SEED)
 
119
  generator=generator,
120
  ).images[0]
121
 
122
+ return image, gr_lora_loaded
123
 
124
 
125
  css = """
 
275
  negative_prompt.submit,
276
  guidance_scale.change,
277
  ],
 
 
 
 
278
  fn=infer,
279
  inputs=[
280
  pre_prompt,
 
284
  num_inference_steps,
285
  negative_prompt,
286
  guidance_scale,
287
+ user_lora_selector,
288
+ user_lora_weight,
289
+ gr_lora_loaded,
290
  ],
291
+ outputs=[result, gr_lora_loaded],
292
  show_progress="minimal",
293
  )
294