TheAwakenOne commited on
Commit
e777edc
·
1 Parent(s): 0b25f63
Files changed (1) hide show
  1. app.py +62 -170
app.py CHANGED
@@ -38,7 +38,7 @@ class calculateDuration:
38
  def __enter__(self):
39
  self.start_time = time.time()
40
  return self
41
-
42
  def __exit__(self, exc_type, exc_value, traceback):
43
  self.end_time = time.time()
44
  self.elapsed_time = self.end_time - self.start_time
@@ -106,15 +106,17 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
106
  joint_attention_kwargs={"scale": lora_scale},
107
  output_type="pil",
108
  ).images[0]
109
- return final_image
110
 
111
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
112
  if selected_index is None:
113
  raise gr.Error("You must select a LoRA before proceeding.")
 
114
  selected_lora = loras[selected_index]
115
  lora_path = selected_lora["repo"]
116
  trigger_word = selected_lora["trigger_word"]
117
- if(trigger_word):
 
118
  if "trigger_position" in selected_lora:
119
  if selected_lora["trigger_position"] == "prepend":
120
  prompt_mash = f"{trigger_word} {prompt}"
@@ -128,10 +130,10 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
128
  with calculateDuration("Unloading LoRA"):
129
  pipe.unload_lora_weights()
130
  pipe_i2i.unload_lora_weights()
131
-
132
  # Load LoRA weights
133
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
134
- if(image_input is not None):
135
  if "weights" in selected_lora:
136
  pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
137
  else:
@@ -141,194 +143,84 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
141
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
142
  else:
143
  pipe.load_lora_weights(lora_path)
144
-
145
  # Set random seed for reproducibility
146
  with calculateDuration("Randomizing seed"):
147
  if randomize_seed:
148
  seed = random.randint(0, MAX_SEED)
149
-
150
- if(image_input is not None):
151
-
152
  final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
153
  yield final_image, seed, gr.update(visible=False)
154
  else:
155
  image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
156
-
157
  # Consume the generator to get the final image
158
  final_image = None
159
  step_counter = 0
160
  for image in image_generator:
161
- step_counter+=1
162
  final_image = image
163
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
164
- yield image, seed, gr.update(value=progress_bar, visible=True)
165
-
166
- yield final_image, seed, gr.update(value=progress_bar, visible=False)
167
 
168
- def get_huggingface_safetensors(link):
169
- split_link = link.split("/")
170
- if(len(split_link) == 2):
171
- model_card = ModelCard.load(link)
172
- base_model = model_card.data.get("base_model")
173
- print(base_model)
174
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
175
- raise Exception("Not a FLUX LoRA!")
176
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
177
- trigger_word = model_card.data.get("instance_prompt", "")
178
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
179
- fs = HfFileSystem()
180
- try:
181
- list_of_files = fs.ls(link, detail=False)
182
- for file in list_of_files:
183
- if(file.endswith(".safetensors")):
184
- safetensors_name = file.split("/")[-1]
185
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
186
- image_elements = file.split("/")
187
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
188
- except Exception as e:
189
- print(e)
190
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
191
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
192
- return split_link[1], link, safetensors_name, trigger_word, image_url
193
-
194
- def check_custom_model(link):
195
- if(link.startswith("https://")):
196
- if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
197
- link_split = link.split("huggingface.co/")
198
- return get_huggingface_safetensors(link_split[1])
199
- else:
200
- return get_huggingface_safetensors(link)
201
 
202
- def add_custom_lora(custom_lora):
203
- global loras
204
- if(custom_lora):
205
- try:
206
- title, repo, path, trigger_word, image = check_custom_model(custom_lora)
207
- print(f"Loaded custom LoRA: {repo}")
208
- card = f'''
209
- <div class="custom_lora_card">
210
- <span>Loaded custom LoRA:</span>
211
- <div class="card_internal">
212
- <img src="{image}" />
213
- <div>
214
- <h3>{title}</h3>
215
- <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
216
- </div>
217
- </div>
218
- </div>
219
- '''
220
- existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
221
- if(not existing_item_index):
222
- new_item = {
223
- "image": image,
224
- "title": title,
225
- "repo": repo,
226
- "weights": path,
227
- "trigger_word": trigger_word
228
- }
229
- print(new_item)
230
- existing_item_index = len(loras)
231
- loras.append(new_item)
232
-
233
- return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
234
- except Exception as e:
235
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
236
- return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None, ""
237
- else:
238
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
239
 
240
- def remove_custom_lora():
241
- return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
 
 
 
 
242
 
243
- run_lora.zerogpu = True
 
 
244
 
245
- css = '''
246
- #gen_btn{height: 100%}
247
- #gen_column{align-self: stretch}
248
- #title{text-align: center}
249
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
250
- #title img{width: 100px; margin-right: 0.5em}
251
- #gallery .grid-wrap{height: 10vh}
252
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
253
- .card_internal{display: flex;height: 100px;margin-top: .5em}
254
- .card_internal img{margin-right: 1em}
255
- .styler{--form-gap-width: 0px !important}
256
- #progress{height:30px}
257
- #progress .generating{display:none}
258
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
259
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
260
- '''
261
- font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
262
- with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 3600)) as app:
263
- title = gr.HTML(
264
- """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
265
- elem_id="title",
266
- )
267
- selected_index = gr.State(None)
268
  with gr.Row():
269
- with gr.Column(scale=3):
270
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
271
- with gr.Column(scale=1, elem_id="gen_column"):
272
- generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
 
273
  with gr.Row():
274
  with gr.Column():
275
- selected_info = gr.Markdown("")
276
- gallery = gr.Gallery(
277
- [(item["image"], item["title"]) for item in loras],
278
- label="LoRA Gallery",
279
- allow_preview=False,
280
- columns=3,
281
- elem_id="gallery",
282
- show_share_button=False
283
- )
284
- with gr.Group():
285
- custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
286
- gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
287
- custom_lora_info = gr.HTML(visible=False)
288
- custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
289
  with gr.Column():
290
- progress_bar = gr.Markdown(elem_id="progress",visible=False)
291
- result = gr.Image(label="Generated Image")
 
292
 
293
  with gr.Row():
294
- with gr.Accordion("Advanced Settings", open=False):
295
- with gr.Row():
296
- input_image = gr.Image(label="Input image", type="filepath")
297
- image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
298
- with gr.Column():
299
- with gr.Row():
300
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
301
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
302
-
303
- with gr.Row():
304
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
305
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
306
-
307
- with gr.Row():
308
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
309
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
310
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
311
 
312
- gallery.select(
313
- update_selection,
314
- inputs=[width, height],
315
- outputs=[prompt, selected_info, selected_index, width, height]
316
- )
317
- custom_lora.input(
318
- add_custom_lora,
319
- inputs=[custom_lora],
320
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
321
- )
322
- custom_lora_button.click(
323
- remove_custom_lora,
324
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
325
- )
326
- gr.on(
327
- triggers=[generate_button.click, prompt.submit],
328
- fn=run_lora,
329
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
330
- outputs=[result, seed, progress_bar]
331
- )
332
 
333
- app.queue()
334
- app.launch()
 
38
  def __enter__(self):
39
  self.start_time = time.time()
40
  return self
41
+
42
  def __exit__(self, exc_type, exc_value, traceback):
43
  self.end_time = time.time()
44
  self.elapsed_time = self.end_time - self.start_time
 
106
  joint_attention_kwargs={"scale": lora_scale},
107
  output_type="pil",
108
  ).images[0]
109
+ return final_image
110
 
111
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
112
  if selected_index is None:
113
  raise gr.Error("You must select a LoRA before proceeding.")
114
+
115
  selected_lora = loras[selected_index]
116
  lora_path = selected_lora["repo"]
117
  trigger_word = selected_lora["trigger_word"]
118
+
119
+ if trigger_word:
120
  if "trigger_position" in selected_lora:
121
  if selected_lora["trigger_position"] == "prepend":
122
  prompt_mash = f"{trigger_word} {prompt}"
 
130
  with calculateDuration("Unloading LoRA"):
131
  pipe.unload_lora_weights()
132
  pipe_i2i.unload_lora_weights()
133
+
134
  # Load LoRA weights
135
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
136
+ if image_input is not None:
137
  if "weights" in selected_lora:
138
  pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
139
  else:
 
143
  pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
144
  else:
145
  pipe.load_lora_weights(lora_path)
146
+
147
  # Set random seed for reproducibility
148
  with calculateDuration("Randomizing seed"):
149
  if randomize_seed:
150
  seed = random.randint(0, MAX_SEED)
151
+
152
+ if image_input is not None:
 
153
  final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
154
  yield final_image, seed, gr.update(visible=False)
155
  else:
156
  image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
 
157
  # Consume the generator to get the final image
158
  final_image = None
159
  step_counter = 0
160
  for image in image_generator:
161
+ step_counter += 1
162
  final_image = image
163
+ progress_bar = f'Generating image... Step {step_counter}/{steps}'
164
+ yield image, seed, gr.update(visible=True, value=progress_bar)
 
 
165
 
166
+ yield final_image, seed, gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ # Gradio interface
169
+ with gr.Blocks() as demo:
170
+ gr.Markdown("# Awaken Ones' Lora Previews")
171
+ gr.Markdown("Select a LoRA model from the gallery below to get started!")
172
+
173
+ with gr.Row():
174
+ gallery = gr.Gallery(
175
+ value=[lora["image"] for lora in loras],
176
+ label="LoRA Gallery",
177
+ show_label=False,
178
+ elem_id="gallery",
179
+ columns=[5],
180
+ rows=[3],
181
+ object_fit="contain",
182
+ height="auto",
183
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ with gr.Row():
186
+ prompt = gr.Textbox(
187
+ label="Prompt",
188
+ placeholder="Type your prompt here...",
189
+ show_label=True,
190
+ )
191
+ image_input = gr.Image(type="filepath", label="Image Input (Optional)")
192
 
193
+ with gr.Row():
194
+ generate = gr.Button("Generate", variant="primary")
195
+ cancel = gr.Button("Cancel")
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  with gr.Row():
198
+ with gr.Column(scale=4):
199
+ result = gr.Image(label="Result", show_label=False, elem_id="result")
200
+ with gr.Column(scale=1):
201
+ seed_output = gr.Number(label="Seed", interactive=False)
202
+
203
  with gr.Row():
204
  with gr.Column():
205
+ steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Steps")
206
+ cfg_scale = gr.Slider(minimum=1, maximum=20, value=3.5, step=0.1, label="CFG Scale")
207
+ lora_scale = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.05, label="LoRA Scale")
 
 
 
 
 
 
 
 
 
 
 
208
  with gr.Column():
209
+ width = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Width")
210
+ height = gr.Slider(minimum=256, maximum=1024, value=512, step=64, label="Height")
211
+ image_strength = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.05, label="Image Strength")
212
 
213
  with gr.Row():
214
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
215
+ seed_input = gr.Number(label="Seed", value=0, interactive=True, visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ selected_lora = gr.Markdown("### No LoRA selected")
218
+ progress_bar = gr.Markdown(visible=False)
219
+
220
+ # Event handlers
221
+ gallery.select(update_selection, [width, height], [prompt, selected_lora, gr.State(), width, height])
222
+ randomize_seed.change(lambda x: gr.update(visible=not x), randomize_seed, seed_input)
223
+ generate.click(run_lora, inputs=[prompt, image_input, image_strength, cfg_scale, steps, gr.State(), randomize_seed, seed_input, width, height, lora_scale], outputs=[result, seed_output, progress_bar])
224
+ cancel.click(lambda: None, None, None, cancels=[generate])
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ demo.queue().launch()