John6666 committed on
Commit 4b6439a
1 Parent(s): b7778c4

Upload 6 files

Files changed (6):
  1. app.py +7 -3
  2. dc.py +101 -41
  3. env.py +5 -0
  4. llmdolphin.py +40 -0
  5. modutils.py +63 -35
  6. requirements.txt +0 -1
app.py CHANGED
@@ -138,9 +138,13 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
         lora5_copy = gr.Button(value="Copy example to prompt", visible=False)
         lora5_md = gr.Markdown(value="", visible=False)
     with gr.Accordion("From URL", open=True, visible=True):
+        with gr.Row():
+            lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Pony", "SD 1.5", "SDXL 1.0", "Flux.1 D", "Flux.1 S"], value=["Pony", "SDXL 1.0"])
+            lora_search_civitai_sort = gr.Radio(label="Sort", choices=["Highest Rated", "Most Downloaded", "Newest"], value="Highest Rated")
+            lora_search_civitai_period = gr.Radio(label="Period", choices=["AllTime", "Year", "Month", "Week", "Day"], value="AllTime")
         with gr.Row():
             lora_search_civitai_query = gr.Textbox(label="Query", placeholder="oomuro sakurako...", lines=1)
-            lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Pony", "SD 1.5", "SDXL 1.0"], value=["Pony", "SDXL 1.0"])
+            lora_search_civitai_tag = gr.Textbox(label="Tag", lines=1)
         lora_search_civitai_submit = gr.Button("Search on Civitai")
         with gr.Row():
             lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
@@ -247,9 +251,9 @@ with gr.Blocks(fill_width=True, elem_id="container", css=css, delete_cache=(60,
     lora5_copy.click(apply_lora_prompt, [prompt, lora5_info], [prompt], queue=False, show_api=False)

     gr.on(
-        triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
+        triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit, lora_search_civitai_tag.submit],
         fn=search_civitai_lora,
-        inputs=[lora_search_civitai_query, lora_search_civitai_basemodel],
+        inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period, lora_search_civitai_tag],
         outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
         scroll_to_output=True,
         queue=True,
 
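The Tag textbox above is wired into the same handler as the Query textbox, so pressing Enter in either field (or clicking the button) runs one search. A minimal, self-contained sketch of that multi-trigger pattern with Gradio's gr.on; the component and handler names below are hypothetical stand-ins, not the app's code:

import gradio as gr

def search(query, tag):
    # Stand-in handler: echo what would be sent to the search backend.
    return f"query={query!r}, tag={tag!r}"

with gr.Blocks() as demo:
    query = gr.Textbox(label="Query")
    tag = gr.Textbox(label="Tag")
    submit = gr.Button("Search")
    result = gr.Textbox(label="Result")
    # One handler, three triggers: a button click or Enter in either textbox.
    gr.on(triggers=[submit.click, query.submit, tag.submit],
          fn=search, inputs=[query, tag], outputs=[result])

demo.launch()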
dc.py CHANGED
@@ -21,12 +21,9 @@ from stablepy import (
     SD15_TASKS,
     SDXL_TASKS,
 )
+import time
 #import urllib.parse
 import gradio as gr
-from PIL import Image
-import IPython.display
-import time, json
-from IPython.utils import capture
 import logging
 logging.getLogger("diffusers").setLevel(logging.ERROR)
 import diffusers
@@ -381,7 +378,7 @@ class GuiSD:
         if vae_model:
             vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
             if model_type != vae_type:
-                gr.Info(msg_inc_vae)
+                gr.Warning(msg_inc_vae)

         self.model.device = torch.device("cpu")
         dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
@@ -395,7 +392,7 @@
         )
         #yield f"Model loaded: {model_name}"

-    @spaces.GPU
+    #@spaces.GPU
     @torch.inference_mode()
     def generate_pipeline(
         self,
@@ -508,7 +505,7 @@
         vae_model = vae_model if vae_model != "None" else None
         loras_list = [lora1, lora2, lora3, lora4, lora5]
         vae_msg = f"VAE: {vae_model}" if vae_model else ""
-        msg_lora = []
+        msg_lora = ""

         print("Config model:", model_name, vae_model, loras_list)

@@ -679,35 +676,94 @@
         return self.infer_short(self.model, pipe_params, progress), info_state
     ## END MOD

-# def sd_gen_generate_pipeline(*args):
+def dynamic_gpu_duration(func, duration, *args):
+
+    @spaces.GPU(duration=duration)
+    def wrapped_func():
+        yield from func(*args)
+
+    return wrapped_func()
+
+
+@spaces.GPU
+def dummy_gpu():
+    return None
+
+
+def sd_gen_generate_pipeline(*args):
+
+    gpu_duration_arg = int(args[-1]) if args[-1] else 59
+    verbose_arg = int(args[-2])
+    load_lora_cpu = args[-3]
+    generation_args = args[:-3]
+    lora_list = [
+        None if item == "None" or item == "" else item
+        for item in [args[7], args[9], args[11], args[13], args[15]]
+    ]
+    lora_status = [None] * 5
+
+    msg_load_lora = "Updating LoRAs in GPU..."
+    if load_lora_cpu:
+        msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
+
+    if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
+        yield None, msg_load_lora

-#     # Load lora in CPU
-#     status_lora = sd_gen.model.lora_merge(
-#         lora_A=args[7] if args[7] != "None" else None, lora_scale_A=args[8],
-#         lora_B=args[9] if args[9] != "None" else None, lora_scale_B=args[10],
-#         lora_C=args[11] if args[11] != "None" else None, lora_scale_C=args[12],
-#         lora_D=args[13] if args[13] != "None" else None, lora_scale_D=args[14],
-#         lora_E=args[15] if args[15] != "None" else None, lora_scale_E=args[16],
-#     )
+    # Load lora in CPU
+    if load_lora_cpu:
+        lora_status = sd_gen.model.lora_merge(
+            lora_A=lora_list[0], lora_scale_A=args[8],
+            lora_B=lora_list[1], lora_scale_B=args[10],
+            lora_C=lora_list[2], lora_scale_C=args[12],
+            lora_D=lora_list[3], lora_scale_D=args[14],
+            lora_E=lora_list[4], lora_scale_E=args[16],
+        )
+        print(lora_status)

-#     lora_list = [args[7], args[9], args[11], args[13], args[15]]
-#     print(status_lora)
-#     for status, lora in zip(status_lora, lora_list):
-#         if status:
-#             gr.Info(f"LoRA loaded: {lora}")
-#         elif status is not None:
-#             gr.Warning(f"Failed to load LoRA: {lora}")
+    if verbose_arg:
+        for status, lora in zip(lora_status, lora_list):
+            if status:
+                gr.Info(f"LoRA loaded in CPU: {lora}")
+            elif status is not None:
+                gr.Warning(f"Failed to load LoRA: {lora}")

-#     # if status_lora == [None] * 5 and self.model.lora_memory != [None] * 5:
-#     #     gr.Info(f"LoRAs in cache: {", ".join(str(x) for x in self.model.lora_memory if x is not None)}")
+    if lora_status == [None] * 5 and sd_gen.model.lora_memory != [None] * 5 and load_lora_cpu:
+        lora_cache_msg = ", ".join(
+            str(x) for x in sd_gen.model.lora_memory if x is not None
+        )
+        gr.Info(f"LoRAs in cache: {lora_cache_msg}")

+    msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
+    gr.Info(msg_request)
+    print(msg_request)

-#     yield from sd_gen.generate_pipeline(*args)
+    # yield from sd_gen.generate_pipeline(*generation_args)

+    start_time = time.time()

+    yield from dynamic_gpu_duration(
+        sd_gen.generate_pipeline,
+        gpu_duration_arg,
+        *generation_args,
+    )
+
+    end_time = time.time()

-# sd_gen_generate_pipeline.zerogpu = True
+    if verbose_arg:
+        execution_time = end_time - start_time
+        msg_task_complete = (
+            f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
+        )
+        gr.Info(msg_task_complete)
+        print(msg_task_complete)

+
+dynamic_gpu_duration.zerogpu = True
+sd_gen_generate_pipeline.zerogpu = True

 from pathlib import Path
-import random
+from PIL import Image
+import random, json
 from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
     get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
     get_valid_lora_path, get_valid_lora_wt, get_lora_info,
@@ -723,6 +779,10 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
    import numpy as np
    MAX_SEED = np.iinfo(np.int32).max

+    load_lora_cpu = False
+    verbose_info = False
+    gpu_duration = 59
+
    images: list[tuple[PIL.Image.Image, str | None]] = []
    info: str = ""
    progress(0, desc="Preparing...")
@@ -739,7 +799,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
    prompt, negative_prompt = insert_model_recom_prompt(prompt, negative_prompt, model_name)
    progress(0.5, desc="Preparing...")
    lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt = \
-        set_prompt_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt)
+        set_prompt_loras(prompt, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt)
    lora1 = get_valid_lora_path(lora1)
    lora2 = get_valid_lora_path(lora2)
    lora3 = get_valid_lora_path(lora3)
@@ -748,7 +808,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
    progress(1, desc="Preparation completed. Starting inference preparation...")

    sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0], progress)
-    images, info = sd_gen.generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
+    images, info = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
        guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
        lora4, lora4_wt, lora5, lora5_wt, sampler,
        height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
@@ -757,7 +817,8 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
        False, True, 1, True, False, False, False, False, "./images", False, False, False, True, 1, 0.55,
        False, False, False, True, False, "Use same sampler", False, "", "", 0.35, True, True, False, 4, 4, 32,
        False, "", "", 0.35, True, True, False, 4, 4, 32,
-        True, None, None, "plus_face", "original", 0.7, None, None, "base", "style", 0.7, 0.0, progress
+        True, None, None, "plus_face", "original", 0.7, None, None, "base", "style", 0.7, 0.0,
+        load_lora_cpu, verbose_info, gpu_duration
    )

    progress(1, desc="Inference completed.")
@@ -820,7 +881,7 @@ def get_t2i_model_info(repo_id: str):
        if " " in repo_id or not api.repo_exists(repo_id): return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
-        print(f"Error: Failed to get {repo_id}'s info. ")
+        print(f"Error: Failed to get {repo_id}'s info. {e}")
        return ""
    if model.private or model.gated: return ""
    tags = model.tags
@@ -1013,13 +1074,13 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
        gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)


-def set_prompt_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
+def set_prompt_loras(prompt, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
    import re
-    lora1 = get_valid_lora_name(lora1)
-    lora2 = get_valid_lora_name(lora2)
-    lora3 = get_valid_lora_name(lora3)
-    lora4 = get_valid_lora_name(lora4)
-    lora5 = get_valid_lora_name(lora5)
+    lora1 = get_valid_lora_name(lora1, model_name)
+    lora2 = get_valid_lora_name(lora2, model_name)
+    lora3 = get_valid_lora_name(lora3, model_name)
+    lora4 = get_valid_lora_name(lora4, model_name)
+    lora5 = get_valid_lora_name(lora5, model_name)
    if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
    lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
    lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
@@ -1129,9 +1190,9 @@ def update_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora
        gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)


-def search_civitai_lora(query, base_model):
+def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
    global civitai_lora_last_results
-    items = search_lora_on_civitai(query, base_model)
+    items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
    if not items: return gr.update(choices=[("", "")], value="", visible=False),\
        gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
    civitai_lora_last_results = {}
@@ -1324,7 +1385,6 @@ def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None",
    return gr.update(value=prompt), gr.update(value=neg_prompt)


-from PIL import Image
 def save_images(images: list[Image.Image], metadatas: list[str]):
    from PIL import PngImagePlugin
    try:
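The new sd_gen_generate_pipeline receives its three control values (LoRA-on-CPU flag, verbosity flag, GPU duration) as the last three positional arguments and strips them off before forwarding the rest to the real generator. A toy illustration of that trailing-argument convention; the handler and values here are hypothetical:

def handler(*args):
    # Control values ride at the end of the positional argument list.
    gpu_duration = int(args[-1]) if args[-1] else 59
    verbose = bool(args[-2])
    load_lora_cpu = args[-3]
    generation_args = args[:-3]
    return generation_args, load_lora_cpu, verbose, gpu_duration

print(handler("prompt", "neg", 28, False, True, 59))
# -> (('prompt', 'neg', 28), False, True, 59)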
 
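The heart of this file's change is the dynamic_gpu_duration wrapper: instead of a fixed @spaces.GPU decoration, the decorator is applied at call time so each request can ask for its own GPU budget. A minimal sketch of the pattern, assuming the Hugging Face spaces package on ZeroGPU hardware; fake_pipeline is a hypothetical stand-in for sd_gen.generate_pipeline:

import spaces

def dynamic_gpu_duration(func, duration, *args):
    # Decorate at call time so the requested GPU budget can vary per call.
    @spaces.GPU(duration=duration)
    def wrapped_func():
        yield from func(*args)
    return wrapped_func()

def fake_pipeline(steps):
    for i in range(steps):
        yield f"step {i}"

# Request roughly 59 seconds of GPU time for this one call.
for out in dynamic_gpu_duration(fake_pipeline, 59, 5):
    print(out)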
env.py CHANGED
@@ -102,6 +102,11 @@ load_diffusers_format_model = [
     "Raelina/Raemu-Flux",
 ]

+DIFFUSERS_FORMAT_LORAS = [
+    "nerijs/animation2k-flux",
+    "XLabs-AI/flux-RealismLora",
+]
+
 # List all Models for specified user
 HF_MODEL_USER_LIKES = ["votepurchase"] # sorted by number of likes
 HF_MODEL_USER_EX = ["John6666"] # sorted by a special rule
 
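Both DIFFUSERS_FORMAT_LORAS entries are Diffusers-format Flux LoRAs hosted on the Hub, so they can be loaded by repo id rather than by a downloaded .safetensors path. A minimal sketch of how such a LoRA is typically attached with diffusers; this is an illustration assuming a Flux-capable GPU, not code from this repo:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Diffusers-format LoRAs load straight from a Hub repo id.
pipe.load_lora_weights("XLabs-AI/flux-RealismLora")
pipe.to("cuda")
image = pipe("a photo of a cat", num_inference_steps=28).images[0]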
llmdolphin.py CHANGED
@@ -56,8 +56,46 @@ llm_models = {
     "qwen2.5-lumen-14b-q4_k_m.gguf": ["Lambent/Qwen2.5-Lumen-14B-Q4_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
     "Qwen2.5-14B_Uncensored_Instruct.Q4_K_M.gguf": ["mradermacher/Qwen2.5-14B_Uncensored_Instruct-GGUF", MessagesFormatterType.OPEN_CHAT],
     "Trinas_Nectar-8B-model_stock.i1-Q4_K_M.gguf": ["mradermacher/Trinas_Nectar-8B-model_stock-i1-GGUF", MessagesFormatterType.MISTRAL],
+    "ChatWaifu_22B_v2.0_preview.Q4_K_S.gguf": ["mradermacher/ChatWaifu_22B_v2.0_preview-GGUF", MessagesFormatterType.MISTRAL],
     "ChatWaifu_v1.4.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.4-GGUF", MessagesFormatterType.MISTRAL],
     "ChatWaifu_v1.3.1.Q4_K_M.gguf": ["mradermacher/ChatWaifu_v1.3.1-GGUF", MessagesFormatterType.MISTRAL],
+    "Llama-3.2-3B-Instruct-uncensored.i1-Q5_K_S.gguf": ["mradermacher/Llama-3.2-3B-Instruct-uncensored-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "L3-SthenoMaid-8B-V1-Q5_K_M.gguf": ["bartowski/L3-SthenoMaid-8B-V1-GGUF", MessagesFormatterType.LLAMA_3],
+    "Magot-v2-Gemma2-8k-9B.Q5_K_M.gguf": ["grimjim/Magot-v2-Gemma2-8k-9B-GGUF", MessagesFormatterType.ALPACA],
+    "Dante_9B.i1-Q4_K_M.gguf": ["mradermacher/Dante_9B-i1-GGUF", MessagesFormatterType.ALPACA],
+    "L3.1-Artemis-faustus-8B.i1-Q5_K_M.gguf": ["mradermacher/L3.1-Artemis-faustus-8B-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "EdgeRunner-Command-Nested-FC-v3.i1-Q4_K_M.gguf": ["mradermacher/EdgeRunner-Command-Nested-FC-v3-i1-GGUF", MessagesFormatterType.OPEN_CHAT],
+    "Virgil_9B.i1-Q4_K_M.gguf": ["mradermacher/Virgil_9B-i1-GGUF", MessagesFormatterType.ALPACA],
+    "L3.1-Noraian.i1-Q5_K_M.gguf": ["mradermacher/L3.1-Noraian-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "DARKER-PLANET-Broken-Land-12.15B-D_AU-Q4_k_m.gguf": ["DavidAU/DARKER-PLANET-Broken-Land-12.15B-GGUF", MessagesFormatterType.LLAMA_3],
+    "L3.1-Purosani.i1-Q4_K_M.gguf": ["mradermacher/L3.1-Purosani-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "L3-Darker-Planet-12.15B.i1-Q4_K_S.gguf": ["mradermacher/L3-Darker-Planet-12.15B-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "MFANN-llama3.1-Abliterated-SLERP.i1-Q5_K_M.gguf": ["mradermacher/MFANN-llama3.1-Abliterated-SLERP-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "arsenic-v1-qwen2.5-14B.Q4_K_M.gguf": ["mradermacher/arsenic-v1-qwen2.5-14B-GGUF", MessagesFormatterType.OPEN_CHAT],
+    "Mistral-Nemo-Gutenberg-Doppel-12B.Q4_K_M.gguf": ["mradermacher/Mistral-Nemo-Gutenberg-Doppel-12B-GGUF", MessagesFormatterType.MISTRAL],
+    "Gemma-The-Writer-9B.Q4_K_M.gguf": ["mradermacher/Gemma-The-Writer-9B-GGUF", MessagesFormatterType.ALPACA],
+    "Qwen2.5-14B-Instruct-abliterated.i1-Q4_K_M.gguf": ["mradermacher/Qwen2.5-14B-Instruct-abliterated-i1-GGUF", MessagesFormatterType.OPEN_CHAT],
+    "LongCite-llama3.1-8B-Q5_K_M.gguf": ["LPN64/LongCite-llama3.1-8b-GGUF", MessagesFormatterType.LLAMA_3],
+    "IceDrinkNameGoesHereV0RP-7b-Model_Stock.Q4_K_M.gguf": ["mradermacher/IceDrinkNameGoesHereV0RP-7b-Model_Stock-GGUF", MessagesFormatterType.MISTRAL],
+    "Magnum-Picaro-0.7-v2-12b.i1-Q4_K_M.gguf": ["mradermacher/Magnum-Picaro-0.7-v2-12b-i1-GGUF", MessagesFormatterType.CHATML],
+    "arsenic-v1-qwen2.5-14b-q4_k_m.gguf": ["Lambent/arsenic-v1-qwen2.5-14B-Q4_K_M-GGUF", MessagesFormatterType.OPEN_CHAT],
+    "magnum-v3-9b-chatml.i1-Q4_K_M.gguf": ["mradermacher/magnum-v3-9b-chatml-i1-GGUF", MessagesFormatterType.CHATML],
+    "Aether-12b.Q4_K_M.gguf": ["mradermacher/Aether-12b-GGUF", MessagesFormatterType.MISTRAL],
+    "Magnum-Picaro-0.7-v2-12b.Q4_K_M.gguf": ["mradermacher/Magnum-Picaro-0.7-v2-12b-GGUF", MessagesFormatterType.CHATML],
+    "IceMartiniV1RP-7b.i1-Q4_K_M.gguf": ["mradermacher/IceMartiniV1RP-7b-i1-GGUF", MessagesFormatterType.MISTRAL],
+    "L3-Darker-Planet-12.15B-D_AU-Q4_k_s.gguf": ["DavidAU/L3-Darker-Planet-12.15B-GGUF", MessagesFormatterType.LLAMA_3],
+    "Llama-3-8B-Stroganoff-4.0-Version-B.i1-Q4_K_M.gguf": ["mradermacher/Llama-3-8B-Stroganoff-4.0-Version-B-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "IceDrunkCherryV1RP-7b.Q5_K_S.gguf": ["mradermacher/IceDrunkCherryV1RP-7b-GGUF", MessagesFormatterType.MISTRAL],
+    "Llama-3-8B-Stroganoff-4.0-Version-A.i1-Q4_K_M.gguf": ["mradermacher/Llama-3-8B-Stroganoff-4.0-Version-A-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "vulca-reshapetesting-006.q8_0.gguf": ["kromquant/vulca-reshapetesting-006-GGUFs", MessagesFormatterType.MISTRAL],
+    "FatGirl_v2_DPO_v2_8B.Q4_K_M.gguf": ["mradermacher/FatGirl_v2_DPO_v2_8B-GGUF", MessagesFormatterType.MISTRAL],
+    "NemonsterExtreme-12b.Q4_K_M.gguf": ["mradermacher/NemonsterExtreme-12b-GGUF", MessagesFormatterType.MISTRAL],
+    "Magnum-Instruct-DPO-18B.Q4_K_M.gguf": ["mradermacher/Magnum-Instruct-DPO-18B-GGUF", MessagesFormatterType.MISTRAL],
+    "qwen2.5-reinstruct-alternate-lumen-14B.Q4_K_M.gguf": ["mradermacher/qwen2.5-reinstruct-alternate-lumen-14B-GGUF", MessagesFormatterType.OPEN_CHAT],
+    "SuperHeart.i1-Q5_K_M.gguf": ["mradermacher/SuperHeart-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "SuperHeartBerg.i1-Q4_K_M.gguf": ["mradermacher/SuperHeartBerg-i1-GGUF", MessagesFormatterType.LLAMA_3],
+    "gemma-2-9B-it-function-calling-Q5_K_M.gguf": ["DiTy/gemma-2-9b-it-function-calling-GGUF", MessagesFormatterType.ALPACA],
+    "Qwen2.5-7B-TitanFusion.Q5_K_M.gguf": ["mradermacher/Qwen2.5-7B-TitanFusion-GGUF", MessagesFormatterType.OPEN_CHAT],
     "Collaiborator-MEDLLM-Llama-3-8B-v1.i1-Q5_K_M.gguf": ["mradermacher/Collaiborator-MEDLLM-Llama-3-8B-v1-i1-GGUF", MessagesFormatterType.LLAMA_3],
     "Chili_Dog_8B.i1-Q4_K_M.gguf": ["mradermacher/Chili_Dog_8B-i1-GGUF", MessagesFormatterType.CHATML],
     "astra-v1-12b-q5_k_m.gguf": ["P0x0/Astra-v1-12B-GGUF", MessagesFormatterType.MISTRAL],
@@ -585,6 +623,8 @@ llm_models = {
     "tifa-7b-qwen2-v0.1.q4_k_m.gguf": ["Tifa-RP/Tifa-7B-Qwen2-v0.1-GGUF", MessagesFormatterType.OPEN_CHAT],
     "Holland-Magnum-Merge-R2.i1-Q5_K_M.gguf": ["mradermacher/Holland-Magnum-Merge-R2-i1-GGUF", MessagesFormatterType.LLAMA_3],
     "Oumuamua-7b-RP_Q5_K_M.gguf": ["Aratako/Oumuamua-7b-RP-GGUF", MessagesFormatterType.MISTRAL],
+    "ContextualKunoichi_KTO-7B.Q5_K_M.gguf": ["mradermacher/ContextualKunoichi_KTO-7B-GGUF", MessagesFormatterType.MISTRAL],
+    "ContextualToppy_KTO-7B.Q5_K_M.gguf": ["mradermacher/ContextualToppy_KTO-7B-GGUF", MessagesFormatterType.MISTRAL],
     "Berghof-NSFW-7B.Q5_K_M.gguf": ["QuantFactory/Berghof-NSFW-7B-GGUF", MessagesFormatterType.MISTRAL],
     "Japanese-TextGen-Kage-v0.1.2-2x7B-NSFW_iMat_Ch200_IQ4_XS.gguf": ["dddump/Japanese-TextGen-Kage-v0.1.2-2x7B-NSFW-gguf", MessagesFormatterType.VICUNA],
     "ChatWaifu_v1.2.1.Q5_K_M.gguf": ["mradermacher/ChatWaifu_v1.2.1-GGUF", MessagesFormatterType.MISTRAL],
 
modutils.py CHANGED
@@ -4,13 +4,21 @@ import gradio as gr
 from huggingface_hub import HfApi
 import os
 from pathlib import Path
+from PIL import Image


 from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
-    HF_MODEL_USER_EX, HF_MODEL_USER_LIKES,
+    HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
     directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)


+MODEL_TYPE_DICT = {
+    "diffusers:StableDiffusionPipeline": "SD 1.5",
+    "diffusers:StableDiffusionXLPipeline": "SDXL",
+    "diffusers:FluxPipeline": "FLUX",
+}
+
+
 def get_user_agent():
     return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'

@@ -27,6 +35,11 @@ def list_sub(a, b):
     return [e for e in a if e not in b]


+def is_repo_name(s):
+    import re
+    return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
+
+
 from translatepy import Translator
 translator = Translator()
 def translate_to_en(input: str):
@@ -64,7 +77,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
         if hf_token:
             os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
         else:
-            os.system (f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
+            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
     elif "civitai.com" in url:
         if "?" in url:
             url = url.split("?")[0]
@@ -100,7 +113,6 @@ def safe_float(input):
     return output


-from PIL import Image
 def save_images(images: list[Image.Image], metadatas: list[str]):
     from PIL import PngImagePlugin
     import uuid
@@ -245,10 +257,10 @@ model_id_list = get_model_id_list()


 def get_t2i_model_info(repo_id: str):
-    api = HfApi()
+    api = HfApi(token=HF_TOKEN)
     try:
-        if " " in repo_id or not api.repo_exists(repo_id): return ""
-        model = api.model_info(repo_id=repo_id)
+        if not is_repo_name(repo_id): return ""
+        model = api.model_info(repo_id=repo_id, timeout=5.0)
     except Exception as e:
         print(f"Error: Failed to get {repo_id}'s info.")
         print(e)
@@ -258,9 +270,8 @@ def get_t2i_model_info(repo_id: str):
     info = []
     url = f"https://huggingface.co/{repo_id}/"
     if not 'diffusers' in tags: return ""
-    if 'diffusers:FluxPipeline' in tags: info.append("FLUX.1")
-    elif 'diffusers:StableDiffusionXLPipeline' in tags: info.append("SDXL")
-    elif 'diffusers:StableDiffusionPipeline' in tags: info.append("SD1.5")
+    for k, v in MODEL_TYPE_DICT.items():
+        if k in tags: info.append(v)
     if model.card_data and model.card_data.tags:
         info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
     info.append(f"DLs: {model.downloads}")
@@ -285,12 +296,8 @@ def get_tupled_model_list(model_list):
         tags = model.tags
         info = []
         if not 'diffusers' in tags: continue
-        if 'diffusers:FluxPipeline' in tags:
-            info.append("FLUX.1")
-        if 'diffusers:StableDiffusionXLPipeline' in tags:
-            info.append("SDXL")
-        elif 'diffusers:StableDiffusionPipeline' in tags:
-            info.append("SD1.5")
+        for k, v in MODEL_TYPE_DICT.items():
+            if k in tags: info.append(v)
         if model.card_data and model.card_data.tags:
             info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
         if "pony" in info:
@@ -374,7 +381,7 @@ def get_civitai_info(path):


 def get_lora_model_list():
-    loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras))
+    loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
     loras.insert(0, "None")
     loras.insert(0, "")
     return loras
@@ -483,7 +490,7 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
     gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)


-def get_valid_lora_name(query: str):
+def get_valid_lora_name(query: str, model_name: str):
     path = "None"
     if not query or query == "None": return "None"
     if to_lora_key(query) in loras_dict.keys(): return query
@@ -497,7 +504,7 @@ def get_valid_lora_name(query: str):
         dl_file = download_lora(query)
         if dl_file and Path(dl_file).exists(): return dl_file
     else:
-        dl_file = find_similar_lora(query)
+        dl_file = find_similar_lora(query, model_name)
     if dl_file and Path(dl_file).exists(): return dl_file
     return "None"

@@ -521,14 +528,14 @@ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
     return wt


-def set_prompt_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
+def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
     import re
     if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
-    lora1 = get_valid_lora_name(lora1)
-    lora2 = get_valid_lora_name(lora2)
-    lora3 = get_valid_lora_name(lora3)
-    lora4 = get_valid_lora_name(lora4)
-    lora5 = get_valid_lora_name(lora5)
+    lora1 = get_valid_lora_name(lora1, model_name)
+    lora2 = get_valid_lora_name(lora2, model_name)
+    lora3 = get_valid_lora_name(lora3, model_name)
+    lora4 = get_valid_lora_name(lora4, model_name)
+    lora5 = get_valid_lora_name(lora5, model_name)
     if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
     lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
     lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
@@ -790,16 +797,17 @@ def get_civitai_info(path):
     return items


-def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100):
+def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
+                           sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
     import requests
     from requests.adapters import HTTPAdapter
     from urllib3.util import Retry
-    if not query: return None
     user_agent = get_user_agent()
     headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
     base_url = 'https://civitai.com/api/v1/models'
-    params = {'query': query, 'types': ['LORA'], 'sort': 'Highest Rated', 'period': 'AllTime',
-              'nsfw': 'true', 'supportsGeneration ': 'true', 'limit': limit}
+    params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
+    if query: params["query"] = query
+    if tag: params["tag"] = tag
     session = requests.Session()
     retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
     session.mount("https://", HTTPAdapter(max_retries=retries))
@@ -828,9 +836,9 @@ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1
     return items


-def search_civitai_lora(query, base_model):
+def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
     global civitai_lora_last_results
-    items = search_lora_on_civitai(query, base_model)
+    items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
     if not items: return gr.update(choices=[("", "")], value="", visible=False),\
         gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
     civitai_lora_last_results = {}
@@ -856,7 +864,27 @@ def select_civitai_lora(search_result):
     return gr.update(value=search_result), gr.update(value=md, visible=True)


-def find_similar_lora(q: str):
+LORA_BASE_MODEL_DICT = {
+    "diffusers:StableDiffusionPipeline": ["SD 1.5"],
+    "diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
+    "diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
+}
+
+
+def get_lora_base_model(model_name: str):
+    api = HfApi(token=HF_TOKEN)
+    default = ["Pony", "SDXL 1.0"]
+    try:
+        model = api.model_info(repo_id=model_name, timeout=5.0)
+        tags = model.tags
+        for tag in tags:
+            if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
+    except Exception:
+        return default
+    return default
+
+
+def find_similar_lora(q: str, model_name: str):
     from rapidfuzz.process import extractOne
     from rapidfuzz.utils import default_process
     query = to_lora_key(q)
@@ -879,7 +907,7 @@ def find_similar_lora(q: str):
     print(f"Finding <lora:{query}:...> on Civitai...")
     civitai_query = Path(query).stem if Path(query).is_file() else query
     civitai_query = civitai_query.replace("_", " ").replace("-", " ")
-    base_model = ["Pony", "SDXL 1.0"]
+    base_model = get_lora_base_model(model_name)
     items = search_lora_on_civitai(civitai_query, base_model, 1)
     if items:
         item = items[0]
@@ -1241,11 +1269,11 @@ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_g

 def get_model_pipeline(repo_id: str):
     from huggingface_hub import HfApi
-    api = HfApi()
+    api = HfApi(token=HF_TOKEN)
     default = "StableDiffusionPipeline"
     try:
-        if " " in repo_id or not api.repo_exists(repo_id): return default
-        model = api.model_info(repo_id=repo_id)
+        if not is_repo_name(repo_id): return default
+        model = api.model_info(repo_id=repo_id, timeout=5.0)
     except Exception:
         return default
     if model.private or model.gated: return default
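In the reworked search_lora_on_civitai, query and tag are both optional: with neither set, the request simply returns the top LoRAs for the chosen sort and period. A runnable sketch of the same request shape against the Civitai /api/v1/models endpoint, using only parameters that appear above (the timeout value is an added assumption):

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

params = {'types': ['LORA'], 'sort': 'Highest Rated', 'period': 'AllTime', 'limit': 5, 'nsfw': 'true'}
query, tag = "oomuro sakurako", ""
if query: params["query"] = query
if tag: params["tag"] = tag

session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
session.mount("https://", HTTPAdapter(max_retries=retries))
r = session.get('https://civitai.com/api/v1/models', params=params, timeout=15)  # timeout: assumption
for item in r.json().get("items", []):
    print(item["name"])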
 
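find_similar_lora tries fuzzy matching against the local LoRA list before falling back to a Civitai search; rapidfuzz's extractOne returns the best (match, score, index) triple from a candidate list. A small sketch of that step, with a hypothetical candidate list:

from rapidfuzz.process import extractOne
from rapidfuzz.utils import default_process

candidates = ["oomuro_sakurako_v1", "akaza_akari_v2", "funami_yui_xl"]
# default_process lowercases and strips non-alphanumeric characters before scoring.
result = extractOne("Oomuro Sakurako", candidates, processor=default_process)
if result:
    match, score, index = result
    print(match, score)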
requirements.txt CHANGED
@@ -8,7 +8,6 @@ git+https://github.com/R3gm/stablepy.git@flux_beta
 torch==2.2.0
 gdown
 opencv-python
-yt-dlp
 huggingface_hub
 scikit-build-core
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.90-cu124/llama_cpp_python-0.2.90-cp310-cp310-linux_x86_64.whl
 