Shaamallow commited on
Commit
d8ef2e5
β€’
1 Parent(s): 7d4c917

fix LoRA swap on Zero GPU

Browse files
Files changed (1) hide show
  1. app.py +19 -34
app.py CHANGED
@@ -39,25 +39,6 @@ MAX_SEED = np.iinfo(np.int32).max
39
  MAX_IMAGE_SIZE = 1024
40
 
41
 
42
- def check_and_load_lora_user(user_lora_selector, user_lora_weight, gr_lora_loaded):
43
- flash_sdxl_id = "jasperai/flash-sdxl"
44
-
45
- if user_lora_selector == "" or user_lora_selector == "":
46
- raise gr.Error("Please select a LoRA before running the inference.")
47
-
48
- if gr_lora_loaded != user_lora_selector:
49
- gr.Info("Loading LoRA")
50
- pipe.unload_lora_weights()
51
- pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
52
- pipe.load_lora_weights(user_lora_selector, adapter_name="user")
53
- pipe.set_adapters(["lora", "user"], adapter_weights=[1.0, user_lora_weight])
54
- gr.Info("LoRA Loaded")
55
-
56
- gr_lora_loaded = user_lora_selector
57
-
58
- return gr_lora_loaded
59
-
60
-
61
  def rescale_lora(user_lora_weight):
62
 
63
  global pipe
@@ -86,22 +67,26 @@ def infer(
86
  guidance_scale,
87
  user_lora_selector,
88
  user_lora_weight,
89
- gr_lora_loaded,
90
  ):
91
  flash_sdxl_id = "jasperai/flash-sdxl"
92
 
93
- if user_lora_selector == "" or user_lora_selector == "":
94
- raise gr.Error("Please select a LoRA before running the inference.")
95
 
96
- if gr_lora_loaded != user_lora_selector:
97
- gr.Info("Loading LoRA")
 
 
 
 
 
 
98
  pipe.unload_lora_weights()
99
  pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
100
- pipe.load_lora_weights(user_lora_selector, adapter_name="user")
101
- pipe.set_adapters(["lora", "user"], adapter_weights=[1.0, user_lora_weight])
102
- gr.Info("LoRA Loaded")
103
 
104
- gr_lora_loaded = user_lora_selector
 
105
 
106
  if randomize_seed:
107
  seed = random.randint(0, MAX_SEED)
@@ -111,6 +96,8 @@ def infer(
111
  if pre_prompt != "":
112
  prompt = f"{pre_prompt} {prompt}"
113
 
 
 
114
  image = pipe(
115
  prompt=prompt,
116
  negative_prompt=negative_prompt,
@@ -119,7 +106,7 @@ def infer(
119
  generator=generator,
120
  ).images[0]
121
 
122
- return image, gr_lora_loaded
123
 
124
 
125
  css = """
@@ -160,7 +147,6 @@ with gr.Blocks(css=css) as demo:
160
  # Index of selected LoRA
161
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
162
  # Serve as memory for currently loaded lora in pipe
163
- gr_lora_loaded = gr.State(value="")
164
  gr_lora_id = gr.State(value="")
165
 
166
  with gr.Row():
@@ -285,11 +271,10 @@ with gr.Blocks(css=css) as demo:
285
  negative_prompt,
286
  guidance_scale,
287
  user_lora_selector,
288
- user_lora_weight,
289
- gr_lora_loaded,
290
  ],
291
- outputs=[result, gr_lora_loaded],
292
- show_progress="minimal",
293
  )
294
 
295
  user_lora_weight.change(
 
39
  MAX_IMAGE_SIZE = 1024
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def rescale_lora(user_lora_weight):
43
 
44
  global pipe
 
67
  guidance_scale,
68
  user_lora_selector,
69
  user_lora_weight,
70
+ progress=gr.Progress(track_tqdm=True)
71
  ):
72
  flash_sdxl_id = "jasperai/flash-sdxl"
73
 
74
+ gr.Info("Checking LoRA")
 
75
 
76
+ new_adapter_id = user_lora_selector.replace("/", "_")
77
+ loaded_adapters = pipe.get_list_adapters()
78
+
79
+ print(loaded_adapters["unet"])
80
+ print(new_adapter_id)
81
+
82
+ if new_adapter_id not in loaded_adapters["unet"]:
83
+ gr.Info("Swapping LoRA")
84
  pipe.unload_lora_weights()
85
  pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
86
+ pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
 
 
87
 
88
+ pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
89
+ gr.Info("LoRA setup done")
90
 
91
  if randomize_seed:
92
  seed = random.randint(0, MAX_SEED)
 
96
  if pre_prompt != "":
97
  prompt = f"{pre_prompt} {prompt}"
98
 
99
+ gr.Info("Generation Stage")
100
+
101
  image = pipe(
102
  prompt=prompt,
103
  negative_prompt=negative_prompt,
 
106
  generator=generator,
107
  ).images[0]
108
 
109
+ return image
110
 
111
 
112
  css = """
 
147
  # Index of selected LoRA
148
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
149
  # Serve as memory for currently loaded lora in pipe
 
150
  gr_lora_id = gr.State(value="")
151
 
152
  with gr.Row():
 
271
  negative_prompt,
272
  guidance_scale,
273
  user_lora_selector,
274
+ user_lora_weight
 
275
  ],
276
+ outputs=[result],
277
+ # show_progress="full",
278
  )
279
 
280
  user_lora_weight.change(