Update app.py
app.py CHANGED
@@ -316,74 +316,71 @@ def randomize_loras(selected_indices, loras_state):
     random_prompt = random.choice(prompt_values)
     return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, random_prompt
 
-def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
-    if custom_lora:
-
-
-
-
-
-
-
-
-
-            "title": title,
-            "repo": repo,
-            "weights": path,
-            "trigger_word": trigger_word
-        }
-        print(f"New LoRA: {new_item}")
-        existing_item_index = len(current_loras)
-        current_loras.append(new_item)
-
-        # Update gallery
-        gallery_items = [(item["image"], item["title"]) for item in current_loras]
-        # Update selected_indices if there's room
-        if len(selected_indices) < 4:
-            selected_indices.append(existing_item_index)
-        else:
-            gr.Warning("You can select up to 4 LoRAs, remove one to select a new one.")
-
-        # Update selected_info and images
-        selected_info_1 = "Select a Celebrity as LoRA 1"
-        selected_info_2 = "Select a LoRA 2"
-        selected_info_3 = "Select a LoRA 3"
-        selected_info_4 = "Select a LoRA 4"
-        lora_scale_1 = 1.15
-        lora_scale_2 = 1.15
-        lora_scale_3 = 0.65
-        lora_scale_4 = 0.65
-        lora_image_1 = None
-        lora_image_2 = None
-        lora_image_3 = None
-        lora_image_4 = None
-        if len(selected_indices) >= 1:
-            lora1 = current_loras[selected_indices[0]]
-            selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
-            lora_image_1 = lora1['image'] if lora1['image'] else None
-
-
-
-            selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
-            lora_image_2 = lora2['image'] if lora2['image'] else None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def add_custom_lora(custom_lora, selected_indices, current_loras, gallery, request: gr.Request = None):
+    if not custom_lora:
+        return current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
+
+    try:
+        # Retrieve user token if running in Spaces
+        user_token = request.headers.get("Authorization", "").replace("Bearer ", "") if request else None
+
+        # Check and load custom LoRA
+        title, repo, path, trigger_word, image = check_custom_model(custom_lora, token=user_token)
+        print(f"Loaded custom LoRA: {repo}")
+
+        # Check if the LoRA already exists in the current list
+        existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
+
+        if existing_item_index is None:
+            # Download if a direct .safetensors URL
+            if repo.endswith(".safetensors") and repo.startswith("http"):
+                repo = download_file(repo)
+
+            # Add the new LoRA
+            new_item = {
+                "image": image or "/home/user/app/custom.png",
+                "title": title,
+                "repo": repo,
+                "weights": path,
+                "trigger_word": trigger_word,
+            }
+            print(f"New LoRA: {new_item}")
+            existing_item_index = len(current_loras)
+            current_loras.append(new_item)
+
+        # Update gallery items
+        gallery_items = [(item["image"], item["title"]) for item in current_loras]
+
+        # Update selected indices
+        if len(selected_indices) < 4:
+            selected_indices.append(existing_item_index)
+        else:
+            raise gr.Error("You can select up to 4 LoRAs. Please remove one to add a new one.")
+
+        # Update selection info and images
+        selected_info = [f"Select a LoRA {i + 1}" for i in range(4)]
+        lora_images = [None] * 4
+        lora_scales = [1.15, 1.15, 0.65, 0.65]
+
+        for idx, sel_idx in enumerate(selected_indices[:4]):
+            lora = current_loras[sel_idx]
+            selected_info[idx] = f"### LoRA {idx + 1} Selected: {lora['title']} ✨"
+            lora_images[idx] = lora.get("image")
+
+        print("Finished adding custom LoRA")
+        return (
+            current_loras,
+            gr.update(value=gallery_items),
+            *selected_info,
+            selected_indices,
+            *lora_scales,
+            *lora_images,
+        )
+
+    except Exception as e:
+        print(e)
+        return (current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),
+        )
 
 def remove_custom_lora(selected_indices, current_loras, gallery):
     if current_loras:
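Note on the hunk above: the new `request: gr.Request = None` parameter relies on Gradio injecting the incoming request into any event handler that declares a type-annotated `gr.Request` argument, so the existing event wiring needs no extra input and the `None` default keeps the function callable outside a request. A minimal sketch of that mechanism, with hypothetical component names rather than the ones actually used in app.py:

    import gradio as gr

    def show_token(text, request: gr.Request = None):
        # Gradio fills `request` automatically because of the type annotation;
        # it is not listed in `inputs` below. Outside a request it stays None.
        token = request.headers.get("authorization", "").replace("Bearer ", "") if request else None
        return f"token present: {bool(token)}"

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Custom LoRA (repo id or URL)")  # hypothetical component
        btn = gr.Button("Add custom LoRA")                       # hypothetical component
        out = gr.Markdown()
        btn.click(show_token, inputs=[box], outputs=[out])

    demo.launch()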
@@ -519,10 +516,10 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
 
 run_lora.zerogpu = True
 
-def get_huggingface_safetensors(link):
+def get_huggingface_safetensors(link, token=None):
     split_link = link.split("/")
     if len(split_link) == 2:
-        model_card = ModelCard.load(link)
+        model_card = ModelCard.load(link, use_auth_token=token)
         base_model = model_card.data.get("base_model")
         print(f"Base model: {base_model}")
         if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
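Threading the token into `ModelCard.load` lets the card of a private or gated repo be read; the card's `base_model` field is what the hunk above uses to reject non-FLUX LoRAs. A small standalone sketch of that check against a public card, with a made-up repo id:

    from huggingface_hub import ModelCard

    card = ModelCard.load("someuser/some-flux-lora")  # hypothetical repo id
    base_model = card.data.get("base_model")
    print(f"Base model: {base_model}")
    if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
        print("Not a FLUX LoRA; app.py raises gr.Error at this point.")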
@@ -530,7 +527,7 @@ def get_huggingface_safetensors(link):
         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
         trigger_word = model_card.data.get("instance_prompt", "")
         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-        fs = HfFileSystem()
+        fs = HfFileSystem(token=token)
         safetensors_name = None
         try:
             list_of_files = fs.ls(link, detail=False)
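Similarly, passing the token to `HfFileSystem` lets the `fs.ls(...)` call in this hunk list files in private or gated repos that an anonymous request could not see. A standalone sketch of the same lookup, with a placeholder repo id and token:

    from huggingface_hub import HfFileSystem

    token = "hf_..."                      # placeholder; app.py takes it from the Authorization header
    repo_id = "someuser/some-flux-lora"   # placeholder LoRA repo

    fs = HfFileSystem(token=token)
    files = fs.ls(repo_id, detail=False)  # paths like "someuser/some-flux-lora/lora.safetensors"
    safetensors_name = next((f.split("/")[-1] for f in files if f.endswith(".safetensors")), None)
    print(safetensors_name)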
@@ -549,24 +546,22 @@ def get_huggingface_safetensors(link):
     else:
         raise gr.Error("Invalid Hugging Face repository link")
 
-def check_custom_model(link):
+def check_custom_model(link, token=None):
     if link.endswith(".safetensors"):
-        # Treat as direct link to the LoRA weights
         title = os.path.basename(link)
         repo = link
-        path = None
+        path = None
         trigger_word = ""
         image_url = None
         return title, repo, path, trigger_word, image_url
     elif link.startswith("https://"):
         if "huggingface.co" in link:
             link_split = link.split("huggingface.co/")
-            return get_huggingface_safetensors(link_split[1])
+            return get_huggingface_safetensors(link_split[1], token=token)
         else:
             raise Exception("Unsupported URL")
     else:
-
-        return get_huggingface_safetensors(link)
+        return get_huggingface_safetensors(link, token=token)
 
 def update_history(new_image, history):
     """Updates the history gallery with the new image."""
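Taken together, `check_custom_model` still accepts the same three input shapes (a direct `.safetensors` URL, a full huggingface.co link, or a bare `user/repo` id) and now forwards the caller's token for the two Hub cases. A usage sketch, assuming app.py's functions are in scope; the URLs, repo ids, and token value are placeholders:

    # Hypothetical calls; identifiers below are not from the app.
    title, repo, path, trigger_word, image_url = check_custom_model(
        "https://example.com/loras/my-style.safetensors")                  # direct file, no Hub lookup
    title, repo, path, trigger_word, image_url = check_custom_model(
        "https://huggingface.co/someuser/some-flux-lora", token="hf_...")  # full Hub URL
    title, repo, path, trigger_word, image_url = check_custom_model(
        "someuser/some-flux-lora", token="hf_...")                         # bare repo id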