John6666 committed on
Commit b7778c4
1 Parent(s): d34f363

Upload 6 files

Files changed (3)
  1. dc.py +135 -137
  2. env.py +1 -1
  3. modutils.py +12 -8
dc.py CHANGED
@@ -1,12 +1,11 @@
 import spaces
 import os
 from stablepy import Model_Diffusers
-from stablepy.diffusers_vanilla.model import scheduler_names
 from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
+from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
 import torch
 import re
-import shutil
-import random
+from huggingface_hub import HfApi
 from stablepy import (
     CONTROLNET_MODEL_IDS,
     VALID_TASKS,
@@ -22,7 +21,7 @@ from stablepy import (
     SD15_TASKS,
     SDXL_TASKS,
 )
-import urllib.parse
+#import urllib.parse
 import gradio as gr
 from PIL import Image
 import IPython.display
@@ -40,7 +39,7 @@ from stablepy import logger
 logger.setLevel(logging.CRITICAL)
 
 from env import (
-    hf_token, hf_read_token, # to use only for private repos
+    HF_TOKEN, hf_read_token, # to use only for private repos
     CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
     HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
     HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
@@ -49,7 +48,7 @@ from env import (
     load_diffusers_format_model, download_model_list, download_lora_list,
     download_vae_list, download_embeds)
 
-preprocessor_controlnet = {
+PREPROCESSOR_CONTROLNET = {
     "openpose": [
         "Openpose",
         "None",
@@ -121,7 +120,7 @@ preprocessor_controlnet = {
     ],
 }
 
-task_stablepy = {
+TASK_STABLEPY = {
     'txt2img': 'txt2img',
     'img2img': 'img2img',
     'inpaint': 'inpaint',
@@ -147,7 +146,35 @@ task_stablepy = {
     'tile ControlNet': 'tile',
 }
 
-task_model_list = list(task_stablepy.keys())
+TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
+
+UPSCALER_DICT_GUI = {
+    None: None,
+    "Lanczos": "Lanczos",
+    "Nearest": "Nearest",
+    'Latent': 'Latent',
+    'Latent (antialiased)': 'Latent (antialiased)',
+    'Latent (bicubic)': 'Latent (bicubic)',
+    'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
+    'Latent (nearest)': 'Latent (nearest)',
+    'Latent (nearest-exact)': 'Latent (nearest-exact)',
+    "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+    "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
+    "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+    "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+    "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+    "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+    "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+    "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
+    "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
+    "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
+    "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
+    "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
+    "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
+    "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
+}
+
+UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
 
 def download_things(directory, url, hf_token="", civitai_api_key=""):
     url = url.strip()
@@ -178,21 +205,19 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
     else:
         os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
 
-
 def get_model_list(directory_path):
     model_list = []
     valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
 
     for filename in os.listdir(directory_path):
         if os.path.splitext(filename)[1] in valid_extensions:
-            name_without_extension = os.path.splitext(filename)[0]
+            # name_without_extension = os.path.splitext(filename)[0]
             file_path = os.path.join(directory_path, filename)
             # model_list.append((name_without_extension, file_path))
             model_list.append(file_path)
             print('\033[34mFILE: ' + file_path + '\033[0m')
     return model_list
 
-
 ## BEGIN MOD
 from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
     get_tupled_model_list, get_lora_model_list, download_private_repo)
@@ -210,24 +235,21 @@ download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
 load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
 ## END MOD
 
-CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
-hf_token = os.environ.get("HF_TOKEN")
-
 # Download stuffs
 for url in [url.strip() for url in download_model.split(',')]:
     if not os.path.exists(f"./models/{url.split('/')[-1]}"):
-        download_things(directory_models, url, hf_token, CIVITAI_API_KEY)
+        download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
 for url in [url.strip() for url in download_vae.split(',')]:
     if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
-        download_things(directory_vaes, url, hf_token, CIVITAI_API_KEY)
+        download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
 for url in [url.strip() for url in download_lora.split(',')]:
     if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
-        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
+        download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
 
 # Download Embeddings
 for url_embed in download_embeds:
     if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
-        download_things(directory_embeds, url_embed, hf_token, CIVITAI_API_KEY)
+        download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)
 
 # Build list models
 embed_list = get_model_list(directory_embeds)
@@ -244,53 +266,45 @@ embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directo
 
 def get_embed_list(pipeline_name):
     return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
-
-
 ## END MOD
 
 print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
 
-upscaler_dict_gui = {
-    None: None,
-    "Lanczos": "Lanczos",
-    "Nearest": "Nearest",
-    'Latent': 'Latent',
-    'Latent (antialiased)': 'Latent (antialiased)',
-    'Latent (bicubic)': 'Latent (bicubic)',
-    'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
-    'Latent (nearest)': 'Latent (nearest)',
-    'Latent (nearest-exact)': 'Latent (nearest-exact)',
-    "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-    "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
-    "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-    "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-    "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
-    "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
-    "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
-    "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
-    "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
-    "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
-    "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
-    "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
-    "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
-    "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
-}
-
-upscaler_keys = list(upscaler_dict_gui.keys())
+msg_inc_vae = (
+    "Use the right VAE for your model to maintain image quality. The wrong"
+    " VAE can lead to poor results, like blurriness in the generated images."
+)
+
+SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
+SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
+FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
+
+MODEL_TYPE_TASK = {
+    "SD 1.5": SD_TASK,
+    "SDXL": SDXL_TASK,
+    "FLUX": FLUX_TASK,
+}
+
+MODEL_TYPE_CLASS = {
+    "diffusers:StableDiffusionPipeline": "SD 1.5",
+    "diffusers:StableDiffusionXLPipeline": "SDXL",
+    "diffusers:FluxPipeline": "FLUX",
+}
+
+POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
 
 def extract_parameters(input_string):
     parameters = {}
     input_string = input_string.replace("\n", "")
 
-    if not "Negative prompt:" in input_string:
+    if "Negative prompt:" not in input_string:
         print("Negative prompt not detected")
         parameters["prompt"] = input_string
         return parameters
 
     parm = input_string.split("Negative prompt:")
     parameters["prompt"] = parm[0]
-    if not "Steps:" in parm[1]:
+    if "Steps:" not in parm[1]:
         print("Steps not detected")
         parameters["neg_prompt"] = parm[1]
         return parameters
@@ -318,6 +332,17 @@ def extract_parameters(input_string):
 
     return parameters
 
+def get_model_type(repo_id: str):
+    api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
+    default = "SD 1.5"
+    try:
+        model = api.model_info(repo_id=repo_id, timeout=5.0)
+        tags = model.tags
+        for tag in tags:
+            if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
+    except Exception:
+        return default
+    return default
 
 ## BEGIN MOD
 class GuiSD:
@@ -348,29 +373,27 @@ class GuiSD:
 
     def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
 
-        yield f"Loading model: {model_name}"
+        #yield f"Loading model: {model_name}"
 
         vae_model = vae_model if vae_model != "None" else None
+        model_type = get_model_type(model_name)
 
-        if model_name in model_list:
-            model_is_xl = "xl" in model_name.lower()
-            sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
-            model_type = "SDXL" if model_is_xl else "SD 1.5"
-            incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
-
-            if incompatible_vae:
-                vae_model = None
+        if vae_model:
+            vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
+            if model_type != vae_type:
+                gr.Info(msg_inc_vae)
 
         self.model.device = torch.device("cpu")
+        dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
 
         self.model.load_pipe(
             model_name,
-            task_name=task_stablepy[task],
+            task_name=TASK_STABLEPY[task],
             vae_model=vae_model if vae_model != "None" else None,
-            type_model_precision=torch.float16 if "flux" not in model_name.lower() else torch.bfloat16,
+            type_model_precision=dtype_model,
             retain_task_model_in_cache=False,
         )
-        yield f"Model loaded: {model_name}"
+        #yield f"Model loaded: {model_name}"
 
     @spaces.GPU
     @torch.inference_mode()
@@ -487,38 +510,15 @@ class GuiSD:
         vae_msg = f"VAE: {vae_model}" if vae_model else ""
         msg_lora = []
 
+        print("Config model:", model_name, vae_model, loras_list)
+
         ## BEGIN MOD
         prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
         global lora_model_list
         lora_model_list = get_lora_model_list()
         ## END MOD
 
-        if model_name in model_list:
-            model_is_xl = "xl" in model_name.lower()
-            sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
-            model_type = "SDXL" if model_is_xl else "SD 1.5"
-            incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
-
-            if incompatible_vae:
-                msg_inc_vae = (
-                    f"The selected VAE is for a { 'SD 1.5' if model_is_xl else 'SDXL' } model, but you"
-                    f" are using a { model_type } model. The default VAE "
-                    "will be used."
-                )
-                gr.Info(msg_inc_vae)
-                vae_msg = msg_inc_vae
-                vae_model = None
-
-        for la in loras_list:
-            if la is not None and la != "None" and la != "" and la in lora_model_list:
-                print(la)
-                lora_type = ("animetarot" in la.lower() or "Hyper-SD15-8steps".lower() in la.lower())
-                if (model_is_xl and lora_type) or (not model_is_xl and not lora_type):
-                    msg_inc_lora = f"The LoRA {la} is for { 'SD 1.5' if model_is_xl else 'SDXL' }, but you are using { model_type }."
-                    gr.Info(msg_inc_lora)
-                    msg_lora.append(msg_inc_lora)
-
-        task = task_stablepy[task]
+        task = TASK_STABLEPY[task]
 
         params_ip_img = []
         params_ip_msk = []
@@ -540,82 +540,53 @@ class GuiSD:
             params_ip_mode.append(modeip)
            params_ip_scale.append(scaleip)
 
-        model_precision = torch.float16 if "flux" not in model_name.lower() else torch.bfloat16
-
-        # First load
-        model_precision = torch.float16
-        if not self.model:
-            print("Loading model...")
-            self.model = Model_Diffusers(
-                base_model_id=model_name,
-                task_name=task,
-                vae_model=vae_model if vae_model != "None" else None,
-                type_model_precision=model_precision,
-                retain_task_model_in_cache=retain_task_cache_gui,
-            )
-
         if task != "txt2img" and not image_control:
             raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
 
         if task == "inpaint" and not image_mask:
             raise ValueError("No mask image found: Specify one in 'Image Mask'")
 
-        if upscaler_model_path in upscaler_keys[:9]:
+        if upscaler_model_path in UPSCALER_KEYS[:9]:
             upscaler_model = upscaler_model_path
         else:
             directory_upscalers = 'upscalers'
             os.makedirs(directory_upscalers, exist_ok=True)
 
-            url_upscaler = upscaler_dict_gui[upscaler_model_path]
+            url_upscaler = UPSCALER_DICT_GUI[upscaler_model_path]
 
             if not os.path.exists(f"./upscalers/{url_upscaler.split('/')[-1]}"):
-                download_things(directory_upscalers, url_upscaler, hf_token)
+                download_things(directory_upscalers, url_upscaler, HF_TOKEN)
 
             upscaler_model = f"./upscalers/{url_upscaler.split('/')[-1]}"
 
         logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)
 
-        print("Config model:", model_name, vae_model, loras_list)
-
-        self.model.load_pipe(
-            model_name,
-            task_name=task,
-            vae_model=vae_model if vae_model != "None" else None,
-            type_model_precision=model_precision,
-            retain_task_model_in_cache=retain_task_cache_gui,
-        )
-
-        ## BEGIN MOD
-        # if textual_inversion and self.model.class_name == "StableDiffusionXLPipeline":
-        #     print("No Textual inversion for SDXL")
-        ## END MOD
-
         adetailer_params_A = {
-            "face_detector_ad" : face_detector_ad_a,
-            "person_detector_ad" : person_detector_ad_a,
-            "hand_detector_ad" : hand_detector_ad_a,
+            "face_detector_ad": face_detector_ad_a,
+            "person_detector_ad": person_detector_ad_a,
+            "hand_detector_ad": hand_detector_ad_a,
             "prompt": prompt_ad_a,
-            "negative_prompt" : negative_prompt_ad_a,
-            "strength" : strength_ad_a,
+            "negative_prompt": negative_prompt_ad_a,
+            "strength": strength_ad_a,
             # "image_list_task" : None,
-            "mask_dilation" : mask_dilation_a,
-            "mask_blur" : mask_blur_a,
-            "mask_padding" : mask_padding_a,
-            "inpaint_only" : adetailer_inpaint_only,
-            "sampler" : adetailer_sampler,
+            "mask_dilation": mask_dilation_a,
+            "mask_blur": mask_blur_a,
+            "mask_padding": mask_padding_a,
+            "inpaint_only": adetailer_inpaint_only,
+            "sampler": adetailer_sampler,
         }
 
         adetailer_params_B = {
-            "face_detector_ad" : face_detector_ad_b,
-            "person_detector_ad" : person_detector_ad_b,
-            "hand_detector_ad" : hand_detector_ad_b,
+            "face_detector_ad": face_detector_ad_b,
+            "person_detector_ad": person_detector_ad_b,
+            "hand_detector_ad": hand_detector_ad_b,
             "prompt": prompt_ad_b,
-            "negative_prompt" : negative_prompt_ad_b,
-            "strength" : strength_ad_b,
+            "negative_prompt": negative_prompt_ad_b,
+            "strength": strength_ad_b,
             # "image_list_task" : None,
-            "mask_dilation" : mask_dilation_b,
-            "mask_blur" : mask_blur_b,
-            "mask_padding" : mask_padding_b,
+            "mask_dilation": mask_dilation_b,
+            "mask_blur": mask_blur_b,
+            "mask_padding": mask_padding_b,
         }
         pipe_params = {
             "prompt": prompt,
@@ -708,8 +679,35 @@ class GuiSD:
         return self.infer_short(self.model, pipe_params, progress), info_state
 ## END MOD
 
+# def sd_gen_generate_pipeline(*args):
+
+#     # Load lora in CPU
+#     status_lora = sd_gen.model.lora_merge(
+#         lora_A=args[7] if args[7] != "None" else None, lora_scale_A=args[8],
+#         lora_B=args[9] if args[9] != "None" else None, lora_scale_B=args[10],
+#         lora_C=args[11] if args[11] != "None" else None, lora_scale_C=args[12],
+#         lora_D=args[13] if args[13] != "None" else None, lora_scale_D=args[14],
+#         lora_E=args[15] if args[15] != "None" else None, lora_scale_E=args[16],
+#     )
+
+#     lora_list = [args[7], args[9], args[11], args[13], args[15]]
+#     print(status_lora)
+#     for status, lora in zip(status_lora, lora_list):
+#         if status:
+#             gr.Info(f"LoRA loaded: {lora}")
+#         elif status is not None:
+#             gr.Warning(f"Failed to load LoRA: {lora}")
+
+#     # if status_lora == [None] * 5 and self.model.lora_memory != [None] * 5:
+#     #     gr.Info(f"LoRAs in cache: {", ".join(str(x) for x in self.model.lora_memory if x is not None)}")
+
+#     yield from sd_gen.generate_pipeline(*args)
+
+
+# sd_gen_generate_pipeline.zerogpu = True
 
 from pathlib import Path
+import random
 from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
     get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
     get_valid_lora_path, get_valid_lora_wt, get_lora_info,
@@ -749,11 +747,11 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
     lora5 = get_valid_lora_path(lora5)
     progress(1, desc="Preparation completed. Starting inference preparation...")
 
-    sd_gen.load_new_model(model_name, vae, task_model_list[0], progress)
+    sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0], progress)
     images, info = sd_gen.generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
         guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
         lora4, lora4_wt, lora5, lora5_wt, sampler,
-        height, width, model_name, vae, task_model_list[0], None, "Canny", 512, 1024,
+        height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
        None, None, None, 0.35, 100, 200, 0.1, 0.1, 1.0, 0., 1., False, "Classic", None,
         1.0, 100, 10, 30, 0.55, "Use same sampler", "", "",
         False, True, 1, True, False, False, False, False, "./images", False, False, False, True, 1, 0.55,
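The upscaler handling in generate_pipeline keeps its old shape under the renamed UPSCALER_DICT_GUI/UPSCALER_KEYS constants: the first nine keys (None, "Lanczos", "Nearest", and the six Latent modes) are built-in scaler names passed through as-is, while every other entry is a URL downloaded on first use and cached under ./upscalers. Condensed into a hypothetical helper (resolve_upscaler is not in the commit; it merely restates the inline logic):

    def resolve_upscaler(key):
        value = UPSCALER_DICT_GUI[key]
        if key in UPSCALER_KEYS[:9]:
            return value  # built-in scaler name or latent mode, used directly
        os.makedirs("upscalers", exist_ok=True)
        local = f"./upscalers/{value.split('/')[-1]}"
        if not os.path.exists(local):
            download_things("upscalers", value, HF_TOKEN)  # fetch once, then reuse
        return local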
env.py CHANGED
@@ -1,7 +1,7 @@
 import os
 
 CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
-hf_token = os.environ.get("HF_TOKEN")
+HF_TOKEN = os.environ.get("HF_TOKEN")
 hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
 
 # - **List Models**
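env.py is the single place the tokens are read from the environment, so this rename to HF_TOKEN is what forces the matching call-site updates in dc.py and modutils.py above and below. If a token is mandatory in a given deployment, a small import-time guard is one option (require_env is hypothetical and not part of the commit):

    import os

    def require_env(name: str) -> str:
        value = os.environ.get(name)
        if not value:
            raise RuntimeError(f"{name} is not set; add it to the Space secrets")
        return value

    # HF_TOKEN = require_env("HF_TOKEN")  # instead of os.environ.get("HF_TOKEN")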
modutils.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
 
 from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
     HF_MODEL_USER_EX, HF_MODEL_USER_LIKES,
-    directory_loras, hf_read_token, hf_token, CIVITAI_API_KEY)
+    directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
 
 
 def get_user_agent():
@@ -227,11 +227,16 @@ def get_model_id_list():
     model_ids.append(model.id) if not model.private else ""
     anime_models = []
     real_models = []
+    anime_models_flux = []
+    real_models_flux = []
     for model in models_ex:
-        if not model.private and not model.gated and "diffusers:FluxPipeline" not in model.tags:
-            anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
+        if not model.private and not model.gated:
+            if "diffusers:FluxPipeline" in model.tags: anime_models_flux.append(model.id) if "anime" in model.tags else real_models_flux.append(model.id)
+            else: anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
     model_ids.extend(anime_models)
     model_ids.extend(real_models)
+    model_ids.extend(anime_models_flux)
+    model_ids.extend(real_models_flux)
     model_id_list = model_ids.copy()
     return model_ids
 
@@ -426,7 +431,7 @@ def download_lora(dl_urls: str):
     for url in [url.strip() for url in dl_urls.split(',')]:
         local_path = f"{directory_loras}/{url.split('/')[-1]}"
         if not Path(local_path).exists():
-            download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
+            download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
             urls.append(url)
     after = get_local_model_list(directory_loras)
     new_files = list_sub(after, before)
@@ -688,7 +693,7 @@ def get_my_lora(link_url):
     before = get_local_model_list(directory_loras)
     for url in [url.strip() for url in link_url.split(',')]:
         if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
-            download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
+            download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
     after = get_local_model_list(directory_loras)
     new_files = list_sub(after, before)
     for file in new_files:
@@ -745,8 +750,7 @@ def move_file_lora(filepaths):
 
 
 def get_civitai_info(path):
-    global civitai_not_exists_list
-    global loras_url_to_path_dict
+    global civitai_not_exists_list, loras_url_to_path_dict
     import requests
     from requests.adapters import HTTPAdapter
     from urllib3.util import Retry
@@ -1242,7 +1246,7 @@ def get_model_pipeline(repo_id: str):
     try:
         if " " in repo_id or not api.repo_exists(repo_id): return default
         model = api.model_info(repo_id=repo_id)
-    except Exception as e:
+    except Exception:
         return default
     if model.private or model.gated: return default
     tags = model.tags
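The get_model_id_list change above stops excluding FLUX repos and instead buckets them separately so they sort after the SD/SDXL entries. The same bucketing written long-form, without the one-line conditional expressions, as an equivalent sketch (elements of models_ex are assumed to expose .id, .tags, .private, and .gated, as huggingface_hub ModelInfo objects do):

    def bucket_model_ids(models_ex):
        anime, real, anime_flux, real_flux = [], [], [], []
        for m in models_ex:
            if m.private or m.gated:
                continue  # repos that can't be downloaded are skipped entirely
            is_flux = "diffusers:FluxPipeline" in m.tags
            is_anime = "anime" in m.tags
            if is_flux:
                (anime_flux if is_anime else real_flux).append(m.id)
            else:
                (anime if is_anime else real).append(m.id)
        # non-FLUX ids first, FLUX ids last, matching the extend() order above
        return anime + real + anime_flux + real_flux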