Shaamallow committed
Commit d8ef2e5
Parent(s): 7d4c917

fix LoRA swap on Zero GPU
app.py CHANGED
@@ -39,25 +39,6 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 
-def check_and_load_lora_user(user_lora_selector, user_lora_weight, gr_lora_loaded):
-    flash_sdxl_id = "jasperai/flash-sdxl"
-
-    if user_lora_selector == "" or user_lora_selector == "":
-        raise gr.Error("Please select a LoRA before running the inference.")
-
-    if gr_lora_loaded != user_lora_selector:
-        gr.Info("Loading LoRA")
-        pipe.unload_lora_weights()
-        pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
-        pipe.load_lora_weights(user_lora_selector, adapter_name="user")
-        pipe.set_adapters(["lora", "user"], adapter_weights=[1.0, user_lora_weight])
-        gr.Info("LoRA Loaded")
-
-        gr_lora_loaded = user_lora_selector
-
-    return gr_lora_loaded
-
-
 def rescale_lora(user_lora_weight):
 
     global pipe
@@ -86,22 +67,26 @@ def infer(
     guidance_scale,
     user_lora_selector,
     user_lora_weight,
-
+    progress=gr.Progress(track_tqdm=True)
 ):
     flash_sdxl_id = "jasperai/flash-sdxl"
 
-
-    raise gr.Error("Please select a LoRA before running the inference.")
+    gr.Info("Checking LoRA")
 
-
-
+    new_adapter_id = user_lora_selector.replace("/", "_")
+    loaded_adapters = pipe.get_list_adapters()
+
+    print(loaded_adapters["unet"])
+    print(new_adapter_id)
+
+    if new_adapter_id not in loaded_adapters["unet"]:
+        gr.Info("Swapping LoRA")
         pipe.unload_lora_weights()
         pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
-        pipe.load_lora_weights(user_lora_selector, adapter_name="user")
-        pipe.set_adapters(["lora", "user"], adapter_weights=[1.0, user_lora_weight])
-        gr.Info("LoRA Loaded")
+        pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
 
-
+    pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
+    gr.Info("LoRA setup done")
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -111,6 +96,8 @@ def infer(
     if pre_prompt != "":
         prompt = f"{pre_prompt} {prompt}"
 
+    gr.Info("Generation Stage")
+
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -119,7 +106,7 @@ def infer(
         generator=generator,
     ).images[0]
 
-    return image
+    return image
 
 
 css = """
@@ -160,7 +147,6 @@ with gr.Blocks(css=css) as demo:
     # Index of selected LoRA
     gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
    # Serve as memory for currently loaded lora in pipe
-    gr_lora_loaded = gr.State(value="")
     gr_lora_id = gr.State(value="")
 
     with gr.Row():
@@ -285,11 +271,10 @@ with gr.Blocks(css=css) as demo:
             negative_prompt,
             guidance_scale,
             user_lora_selector,
-            user_lora_weight,
-            gr_lora_loaded,
+            user_lora_weight
         ],
-        outputs=[result, gr_lora_loaded],
-        show_progress="full",
+        outputs=[result],
+        # show_progress="full",
     )
 
     user_lora_weight.change(
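
For context, the core of the fix is to stop tracking the currently loaded LoRA in a Gradio State (gr_lora_loaded) and instead ask the pipeline itself which adapters are attached before swapping, presumably because on Zero GPU the worker serving a request may not be the one that last loaded the LoRA. Below is a minimal standalone sketch of that logic; the ensure_user_lora helper, the SDXL base model id, and the user LoRA repo id are illustrative assumptions, not code from the Space.

# Minimal sketch of the adapter-swap logic from this commit, run outside Gradio.
# The base model id and the user LoRA repo id below are placeholders.
import torch
from diffusers import StableDiffusionXLPipeline

flash_sdxl_id = "jasperai/flash-sdxl"
user_lora_selector = "some-user/some-sdxl-lora"  # placeholder repo id
user_lora_weight = 0.9

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base
    torch_dtype=torch.float16,
).to("cuda" if torch.cuda.is_available() else "cpu")

def ensure_user_lora(pipe, user_lora_selector, user_lora_weight):
    # Derive an adapter name from the repo id (the commit replaces "/" with "_").
    new_adapter_id = user_lora_selector.replace("/", "_")
    # get_list_adapters() reports attached adapters per component, e.g. {"unet": [...]}.
    loaded_adapters = pipe.get_list_adapters()

    # Reload only when the requested user LoRA is not already attached to the UNet.
    if new_adapter_id not in loaded_adapters.get("unet", []):
        pipe.unload_lora_weights()
        pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
        pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)

    # Re-apply the weighting on every call so a changed slider value takes effect.
    pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])

ensure_user_lora(pipe, user_lora_selector, user_lora_weight)

Compared with the removed check_and_load_lora_user, nothing about the loaded adapters is stored on the Gradio side anymore: the check runs against pipe.get_list_adapters() on every request, and set_adapters is re-applied each time.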
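The other notable addition is progress=gr.Progress(track_tqdm=True) in the infer signature, which tells Gradio to mirror any tqdm loop that runs inside the handler (such as the diffusers denoising loop) into the UI's progress bar. A self-contained toy example of that mechanism, not taken from the Space, might look like this:

# Toy demonstration of gr.Progress(track_tqdm=True); the loop below stands in
# for the denoising loop that diffusers runs through tqdm internally.
import time

import gradio as gr
from tqdm import tqdm

def fake_infer(prompt, progress=gr.Progress(track_tqdm=True)):
    # Any tqdm iteration inside this handler is shown as a progress bar
    # in the Gradio front end while the event runs.
    for _ in tqdm(range(20), desc="denoising"):
        time.sleep(0.05)
    return f"done: {prompt}"

demo = gr.Interface(fn=fake_infer, inputs=gr.Textbox(), outputs=gr.Textbox())

if __name__ == "__main__":
    demo.launch()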