Keltezaa commited on
Commit
b50b66c
·
verified ·
1 Parent(s): 715ed1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -74
app.py CHANGED
@@ -316,74 +316,71 @@ def randomize_loras(selected_indices, loras_state):
316
  random_prompt = random.choice(prompt_values)
317
  return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, random_prompt
318
 
319
- def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
320
- if custom_lora:
321
- try:
322
- title, repo, path, trigger_word, image = check_custom_model(custom_lora)
323
- print(f"Loaded custom LoRA: {repo}")
324
- existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
325
- if existing_item_index is None:
326
- if repo.endswith(".safetensors") and repo.startswith("http"):
327
- repo = download_file(repo)
328
- new_item = {
329
- "image": image if image else "/home/user/app/custom.png",
330
- "title": title,
331
- "repo": repo,
332
- "weights": path,
333
- "trigger_word": trigger_word
334
- }
335
- print(f"New LoRA: {new_item}")
336
- existing_item_index = len(current_loras)
337
- current_loras.append(new_item)
338
-
339
- # Update gallery
340
- gallery_items = [(item["image"], item["title"]) for item in current_loras]
341
- # Update selected_indices if there's room
342
- if len(selected_indices) < 4:
343
- selected_indices.append(existing_item_index)
344
- else:
345
- gr.Warning("You can select up to 4 LoRAs, remove one to select a new one.")
346
-
347
- # Update selected_info and images
348
- selected_info_1 = "Select a Celebrity as LoRA 1"
349
- selected_info_2 = "Select a LoRA 2"
350
- selected_info_3 = "Select a LoRA 3"
351
- selected_info_4 = "Select a LoRA 4"
352
- lora_scale_1 = 1.15
353
- lora_scale_2 = 1.15
354
- lora_scale_3 = 0.65
355
- lora_scale_4 = 0.65
356
- lora_image_1 = None
357
- lora_image_2 = None
358
- lora_image_3 = None
359
- lora_image_4 = None
360
- if len(selected_indices) >= 1:
361
- lora1 = current_loras[selected_indices[0]]
362
- selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
363
- lora_image_1 = lora1['image'] if lora1['image'] else None
364
 
365
- if len(selected_indices) >= 2:
366
- lora2 = current_loras[selected_indices[1]]
367
- selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
368
- lora_image_2 = lora2['image'] if lora2['image'] else None
369
 
370
- if len(selected_indices) >= 3:
371
- lora3 = current_loras[selected_indices[2]]
372
- selected_info_3 = f"### LoRA 3 Selected: {lora3['title']} ✨"
373
- lora_image_3 = lora3['image'] if lora3['image'] else None
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
- if len(selected_indices) >= 4:
376
- lora4 = current_loras[selected_indices[3]]
377
- selected_info_4 = f"### LoRA 4 Selected: {lora4['title']} ✨"
378
- lora_image_4 = lora4['image'] if lora4['image'] else None
379
- print("Finished adding custom LoRA")
380
- return (current_loras, gr.update(value=gallery_items), selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4)
381
- except Exception as e:
382
- print(e)
383
- gr.Warning(str(e))
384
- return current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
385
- else:
386
- return current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  def remove_custom_lora(selected_indices, current_loras, gallery):
389
  if current_loras:
@@ -519,10 +516,10 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
519
 
520
  run_lora.zerogpu = True
521
 
522
- def get_huggingface_safetensors(link):
523
  split_link = link.split("/")
524
  if len(split_link) == 2:
525
- model_card = ModelCard.load(link)
526
  base_model = model_card.data.get("base_model")
527
  print(f"Base model: {base_model}")
528
  if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
@@ -530,7 +527,7 @@ def get_huggingface_safetensors(link):
530
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
531
  trigger_word = model_card.data.get("instance_prompt", "")
532
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
533
- fs = HfFileSystem()
534
  safetensors_name = None
535
  try:
536
  list_of_files = fs.ls(link, detail=False)
@@ -549,24 +546,22 @@ def get_huggingface_safetensors(link):
549
  else:
550
  raise gr.Error("Invalid Hugging Face repository link")
551
 
552
- def check_custom_model(link):
553
  if link.endswith(".safetensors"):
554
- # Treat as direct link to the LoRA weights
555
  title = os.path.basename(link)
556
  repo = link
557
- path = None # No specific weight name
558
  trigger_word = ""
559
  image_url = None
560
  return title, repo, path, trigger_word, image_url
561
  elif link.startswith("https://"):
562
  if "huggingface.co" in link:
563
  link_split = link.split("huggingface.co/")
564
- return get_huggingface_safetensors(link_split[1])
565
  else:
566
  raise Exception("Unsupported URL")
567
  else:
568
- # Assume it's a Hugging Face model path
569
- return get_huggingface_safetensors(link)
570
 
571
  def update_history(new_image, history):
572
  """Updates the history gallery with the new image."""
 
316
  random_prompt = random.choice(prompt_values)
317
  return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, random_prompt
318
 
319
+ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery, request: gr.Request = None):
320
+ if not custom_lora:
321
+ return current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
322
+
323
+ try:
324
+ # Retrieve user token if running in Spaces
325
+ user_token = request.headers.get("Authorization", "").replace("Bearer ", "") if request else None
326
+
327
+ # Check and load custom LoRA
328
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora, token=user_token)
329
+ print(f"Loaded custom LoRA: {repo}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ # Check if the LoRA already exists in the current list
332
+ existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
 
 
333
 
334
+ if existing_item_index is None:
335
+ # Download if a direct .safetensors URL
336
+ if repo.endswith(".safetensors") and repo.startswith("http"):
337
+ repo = download_file(repo)
338
+
339
+ # Add the new LoRA
340
+ new_item = {
341
+ "image": image or "/home/user/app/custom.png",
342
+ "title": title,
343
+ "repo": repo,
344
+ "weights": path,
345
+ "trigger_word": trigger_word,
346
+ }
347
+ print(f"New LoRA: {new_item}")
348
+ existing_item_index = len(current_loras)
349
+ current_loras.append(new_item)
350
 
351
+ # Update gallery items
352
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
353
+
354
+ # Update selected indices
355
+ if len(selected_indices) < 4:
356
+ selected_indices.append(existing_item_index)
357
+ else:
358
+ raise gr.Error("You can select up to 4 LoRAs. Please remove one to add a new one.")
359
+
360
+ # Update selection info and images
361
+ selected_info = [f"Select a LoRA {i + 1}" for i in range(4)]
362
+ lora_images = [None] * 4
363
+ lora_scales = [1.15, 1.15, 0.65, 0.65]
364
+
365
+ for idx, sel_idx in enumerate(selected_indices[:4]):
366
+ lora = current_loras[sel_idx]
367
+ selected_info[idx] = f"### LoRA {idx + 1} Selected: {lora['title']} ✨"
368
+ lora_images[idx] = lora.get("image")
369
+
370
+ print("Finished adding custom LoRA")
371
+ return (
372
+ current_loras,
373
+ gr.update(value=gallery_items),
374
+ *selected_info,
375
+ selected_indices,
376
+ *lora_scales,
377
+ *lora_images,
378
+ )
379
+
380
+ except Exception as e:
381
+ print(e)
382
+ return (current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),gr.update(),
383
+ )
384
 
385
  def remove_custom_lora(selected_indices, current_loras, gallery):
386
  if current_loras:
 
516
 
517
  run_lora.zerogpu = True
518
 
519
+ def get_huggingface_safetensors(link, token=None):
520
  split_link = link.split("/")
521
  if len(split_link) == 2:
522
+ model_card = ModelCard.load(link, use_auth_token=token)
523
  base_model = model_card.data.get("base_model")
524
  print(f"Base model: {base_model}")
525
  if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
 
527
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
528
  trigger_word = model_card.data.get("instance_prompt", "")
529
  image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
530
+ fs = HfFileSystem(token=token)
531
  safetensors_name = None
532
  try:
533
  list_of_files = fs.ls(link, detail=False)
 
546
  else:
547
  raise gr.Error("Invalid Hugging Face repository link")
548
 
549
+ def check_custom_model(link, token=None):
550
  if link.endswith(".safetensors"):
 
551
  title = os.path.basename(link)
552
  repo = link
553
+ path = None
554
  trigger_word = ""
555
  image_url = None
556
  return title, repo, path, trigger_word, image_url
557
  elif link.startswith("https://"):
558
  if "huggingface.co" in link:
559
  link_split = link.split("huggingface.co/")
560
+ return get_huggingface_safetensors(link_split[1], token=token)
561
  else:
562
  raise Exception("Unsupported URL")
563
  else:
564
+ return get_huggingface_safetensors(link, token=token)
 
565
 
566
  def update_history(new_image, history):
567
  """Updates the history gallery with the new image."""