John6666 commited on
Commit
0f73bfb
·
verified ·
1 Parent(s): abff629

Upload dc.py

Browse files
Files changed (1) hide show
  1. dc.py +330 -214
dc.py CHANGED
@@ -1,52 +1,33 @@
1
  import spaces
2
  import os
3
  from stablepy import Model_Diffusers
4
- from constants import (
5
- PREPROCESSOR_CONTROLNET,
6
- TASK_STABLEPY,
7
- TASK_MODEL_LIST,
8
- UPSCALER_DICT_GUI,
9
- UPSCALER_KEYS,
10
- PROMPT_W_OPTIONS,
11
- WARNING_MSG_VAE,
12
- SDXL_TASK,
13
- MODEL_TYPE_TASK,
14
- POST_PROCESSING_SAMPLER,
15
-
16
- )
17
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
 
18
  import torch
19
  import re
 
20
  from stablepy import (
 
 
 
 
 
21
  scheduler_names,
 
22
  IP_ADAPTERS_SD,
23
  IP_ADAPTERS_SDXL,
 
 
 
 
24
  )
25
  import time
26
  from PIL import ImageFile
27
- from utils import (
28
- get_model_list,
29
- extract_parameters,
30
- get_model_type,
31
- extract_exif_data,
32
- create_mask_now,
33
- download_diffuser_repo,
34
- progress_step_bar,
35
- html_template_message,
36
- )
37
- from datetime import datetime
38
- import gradio as gr
39
- import logging
40
- import diffusers
41
- import warnings
42
- from stablepy import logger
43
- # import urllib.parse
44
 
45
  ImageFile.LOAD_TRUNCATED_IMAGES = True
46
- # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
47
  print(os.getenv("SPACES_ZERO_GPU"))
48
 
49
- ## BEGIN MOD
50
  import gradio as gr
51
  import logging
52
  logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -57,63 +38,205 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffuse
57
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
58
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
59
  from stablepy import logger
60
- logger.setLevel(logging.DEBUG)
61
 
62
  from env import (
63
- HF_TOKEN, HF_READ_TOKEN, # to use only for private repos
64
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
65
  HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
66
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
67
- DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS,
68
- DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL,
69
- LOAD_DIFFUSERS_FORMAT_MODEL, DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST,
70
- DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
 
72
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
73
  get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)
74
 
75
  # - **Download Models**
76
- download_model = ", ".join(DOWNLOAD_MODEL_LIST)
77
  # - **Download VAEs**
78
- download_vae = ", ".join(DOWNLOAD_VAE_LIST)
79
  # - **Download LoRAs**
80
- download_lora = ", ".join(DOWNLOAD_LORA_LIST)
81
 
82
- #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
83
- download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
84
 
85
- load_diffusers_format_model = list_uniq(LOAD_DIFFUSERS_FORMAT_MODEL + get_model_id_list())
86
  ## END MOD
87
 
88
  # Download stuffs
89
  for url in [url.strip() for url in download_model.split(',')]:
90
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
91
- download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
92
  for url in [url.strip() for url in download_vae.split(',')]:
93
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
94
- download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
95
  for url in [url.strip() for url in download_lora.split(',')]:
96
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
97
- download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
98
 
99
  # Download Embeddings
100
- for url_embed in DOWNLOAD_EMBEDS:
101
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
102
- download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
103
 
104
  # Build list models
105
- embed_list = get_model_list(DIRECTORY_EMBEDS)
106
- model_list = get_model_list(DIRECTORY_MODELS)
107
  model_list = load_diffusers_format_model + model_list
108
-
109
  ## BEGIN MOD
110
  lora_model_list = get_lora_model_list()
111
- vae_model_list = get_model_list(DIRECTORY_VAES)
112
  vae_model_list.insert(0, "None")
113
 
114
- #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
115
- #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
116
- embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
117
 
118
  def get_embed_list(pipeline_name):
119
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
@@ -121,13 +244,99 @@ def get_embed_list(pipeline_name):
121
 
122
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  ## BEGIN MOD
125
  class GuiSD:
126
- def __init__(self, stream=False):
127
  self.model = None
128
- self.status_loading = False
129
- self.sleep_loading = 4
130
- self.last_load = datetime.now()
 
 
 
 
 
 
 
 
 
131
 
132
  def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
133
  #progress(0, desc="Start inference...")
@@ -141,86 +350,31 @@ class GuiSD:
141
  return img
142
 
143
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
144
- vae_model = vae_model if vae_model != "None" else None
145
- model_type = get_model_type(model_name)
146
- dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
147
-
148
- if not os.path.exists(model_name):
149
- _ = download_diffuser_repo(
150
- repo_name=model_name,
151
- model_type=model_type,
152
- revision="main",
153
- token=True,
154
- )
155
-
156
- for i in range(68):
157
- if not self.status_loading:
158
- self.status_loading = True
159
- if i > 0:
160
- time.sleep(self.sleep_loading)
161
- print("Previous model ops...")
162
- break
163
- time.sleep(0.5)
164
- print(f"Waiting queue {i}")
165
- #yield "Waiting queue"
166
-
167
- self.status_loading = True
168
 
169
  #yield f"Loading model: {model_name}"
 
 
 
170
 
171
  if vae_model:
172
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
173
  if model_type != vae_type:
174
- gr.Warning(WARNING_MSG_VAE)
175
-
176
- print("Loading model...")
177
-
178
- try:
179
- start_time = time.time()
180
-
181
- if self.model is None:
182
- self.model = Model_Diffusers(
183
- base_model_id=model_name,
184
- task_name=TASK_STABLEPY[task],
185
- vae_model=vae_model,
186
- type_model_precision=dtype_model,
187
- retain_task_model_in_cache=False,
188
- device="cpu",
189
- )
190
- else:
191
-
192
- if self.model.base_model_id != model_name:
193
- load_now_time = datetime.now()
194
- elapsed_time = (load_now_time - self.last_load).total_seconds()
195
-
196
- if elapsed_time <= 8:
197
- print("Waiting for the previous model's time ops...")
198
- time.sleep(8-elapsed_time)
199
-
200
- self.model.device = torch.device("cpu")
201
- self.model.load_pipe(
202
- model_name,
203
- task_name=TASK_STABLEPY[task],
204
- vae_model=vae_model,
205
- type_model_precision=dtype_model,
206
- retain_task_model_in_cache=False,
207
- )
208
-
209
- end_time = time.time()
210
- self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
211
- except Exception as e:
212
- self.last_load = datetime.now()
213
- self.status_loading = False
214
- self.sleep_loading = 4
215
- raise e
216
-
217
- self.last_load = datetime.now()
218
- self.status_loading = False
219
 
 
 
 
 
 
 
 
 
 
 
220
  #yield f"Model loaded: {model_name}"
221
 
222
  #@spaces.GPU
223
- #@torch.inference_mode()
224
  def generate_pipeline(
225
  self,
226
  prompt,
@@ -325,24 +479,23 @@ class GuiSD:
325
  mode_ip2,
326
  scale_ip2,
327
  pag_scale,
 
328
  ):
329
- info_state = html_template_message("Navigating latent space...")
330
- #yield info_state, gr.update(), gr.update()
331
-
332
  vae_model = vae_model if vae_model != "None" else None
333
  loras_list = [lora1, lora2, lora3, lora4, lora5]
334
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
335
  msg_lora = ""
336
 
 
 
337
  ## BEGIN MOD
338
- loras_list = [s if s else "None" for s in loras_list]
339
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
340
  global lora_model_list
341
  lora_model_list = get_lora_model_list()
342
  ## END MOD
343
 
344
- print("Config model:", model_name, vae_model, loras_list)
345
-
346
  task = TASK_STABLEPY[task]
347
 
348
  params_ip_img = []
@@ -365,9 +518,6 @@ class GuiSD:
365
  params_ip_mode.append(modeip)
366
  params_ip_scale.append(scaleip)
367
 
368
- concurrency = 5
369
- self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
370
-
371
  if task != "txt2img" and not image_control:
372
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
373
 
@@ -439,15 +589,15 @@ class GuiSD:
439
  "high_threshold": high_threshold,
440
  "value_threshold": value_threshold,
441
  "distance_threshold": distance_threshold,
442
- "lora_A": lora1 if lora1 != "None" else None,
443
  "lora_scale_A": lora_scale1,
444
- "lora_B": lora2 if lora2 != "None" else None,
445
  "lora_scale_B": lora_scale2,
446
- "lora_C": lora3 if lora3 != "None" else None,
447
  "lora_scale_C": lora_scale3,
448
- "lora_D": lora4 if lora4 != "None" else None,
449
  "lora_scale_D": lora_scale4,
450
- "lora_E": lora5 if lora5 != "None" else None,
451
  "lora_scale_E": lora_scale5,
452
  ## BEGIN MOD
453
  "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
@@ -497,59 +647,18 @@ class GuiSD:
497
  }
498
 
499
  self.model.device = torch.device("cuda:0")
500
- if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
501
  self.model.pipe.transformer.to(self.model.device)
502
  print("transformer to cuda")
503
 
504
- #return self.infer_short(self.model, pipe_params), info_state
505
-
506
- actual_progress = 0
507
- info_images = gr.update()
508
- for img, seed, image_path, metadata in self.model(**pipe_params):
509
- info_state = progress_step_bar(actual_progress, steps)
510
- actual_progress += concurrency
511
- if image_path:
512
- info_images = f"Seeds: {str(seed)}"
513
- if vae_msg:
514
- info_images = info_images + "<br>" + vae_msg
515
-
516
- if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
517
- msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
518
- print(msg_ram)
519
- msg_lora += f"<br>{msg_ram}"
520
-
521
- for status, lora in zip(self.model.lora_status, self.model.lora_memory):
522
- if status:
523
- msg_lora += f"<br>Loaded: {lora}"
524
- elif status is not None:
525
- msg_lora += f"<br>Error with: {lora}"
526
-
527
- if msg_lora:
528
- info_images += msg_lora
529
-
530
- info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
531
-
532
- download_links = "<br>".join(
533
- [
534
- f'<a href="{path.replace("/images/", "/file=/home/user/app/images/")}" download="{os.path.basename(path)}">Download Image {i + 1}</a>'
535
- for i, path in enumerate(image_path)
536
- ]
537
- )
538
- if save_generated_images:
539
- info_images += f"<br>{download_links}"
540
 
541
- ## BEGIN MOD
542
- img = save_images(img, metadata)
543
  ## END MOD
544
 
545
- info_state = "COMPLETE"
546
-
547
- #yield info_state, img, info_images
548
- return info_state, img, info_images
549
-
550
  def dynamic_gpu_duration(func, duration, *args):
551
 
552
- @torch.inference_mode()
553
  @spaces.GPU(duration=duration)
554
  def wrapped_func():
555
  return func(*args)
@@ -569,7 +678,7 @@ def sd_gen_generate_pipeline(*args):
569
  load_lora_cpu = args[-3]
570
  generation_args = args[:-3]
571
  lora_list = [
572
- None if item == "None" or item == "" else item # MOD
573
  for item in [args[7], args[9], args[11], args[13], args[15]]
574
  ]
575
  lora_status = [None] * 5
@@ -579,7 +688,7 @@ def sd_gen_generate_pipeline(*args):
579
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
580
 
581
  #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
582
- # yield msg_load_lora, gr.update(), gr.update()
583
 
584
  # Load lora in CPU
585
  if load_lora_cpu:
@@ -605,16 +714,14 @@ def sd_gen_generate_pipeline(*args):
605
  )
606
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
607
 
608
- msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
609
- if verbose_arg:
610
  gr.Info(msg_request)
611
  print(msg_request)
612
- #yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
 
613
 
614
  start_time = time.time()
615
 
616
- # yield from sd_gen.generate_pipeline(*generation_args)
617
- #yield from dynamic_gpu_duration(
618
  return dynamic_gpu_duration(
619
  sd_gen.generate_pipeline,
620
  gpu_duration_arg,
@@ -622,19 +729,31 @@ def sd_gen_generate_pipeline(*args):
622
  )
623
 
624
  end_time = time.time()
625
- execution_time = end_time - start_time
626
- msg_task_complete = (
627
- f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
628
- )
629
 
630
  if verbose_arg:
 
 
 
 
631
  gr.Info(msg_task_complete)
632
  print(msg_task_complete)
633
 
634
- yield msg_task_complete, gr.update(), gr.update()
 
 
 
 
 
 
 
 
635
 
 
636
 
637
- @spaces.GPU(duration=15)
 
 
 
638
  def esrgan_upscale(image, upscaler_name, upscaler_size):
639
  if image is None: return None
640
 
@@ -656,21 +775,18 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):
656
 
657
  return image_path
658
 
659
-
660
  dynamic_gpu_duration.zerogpu = True
661
  sd_gen_generate_pipeline.zerogpu = True
662
- sd_gen = GuiSD()
663
-
664
 
665
  from pathlib import Path
666
  from PIL import Image
667
  import random, json
668
  from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
669
  get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
670
- get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
671
- normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en, get_t2i_model_info, get_civitai_tag, save_image_history)
672
-
673
 
 
674
  #@spaces.GPU
675
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
676
  model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
@@ -685,7 +801,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
685
  gpu_duration = 59
686
 
687
  images: list[tuple[PIL.Image.Image, str | None]] = []
688
- info_state = info_images = ""
689
  progress(0, desc="Preparing...")
690
 
691
  if randomize_seed:
@@ -712,7 +828,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
712
  sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0])
713
  progress(1, desc="Model loaded.")
714
  progress(0, desc="Starting Inference...")
715
- info_state, images, info_images = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
716
  guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
717
  lora4, lora4_wt, lora5, lora5_wt, sampler,
718
  height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
@@ -892,14 +1008,14 @@ def update_lora_dict(path: str):
892
  def download_lora(dl_urls: str):
893
  global loras_url_to_path_dict
894
  dl_path = ""
895
- before = get_local_model_list(DIRECTORY_LORAS)
896
  urls = []
897
  for url in [url.strip() for url in dl_urls.split(',')]:
898
- local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
899
  if not Path(local_path).exists():
900
- download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
901
  urls.append(url)
902
- after = get_local_model_list(DIRECTORY_LORAS)
903
  new_files = list_sub(after, before)
904
  i = 0
905
  for file in new_files:
 
1
  import spaces
2
  import os
3
  from stablepy import Model_Diffusers
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
5
+ from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
6
  import torch
7
  import re
8
+ from huggingface_hub import HfApi
9
  from stablepy import (
10
+ CONTROLNET_MODEL_IDS,
11
+ VALID_TASKS,
12
+ T2I_PREPROCESSOR_NAME,
13
+ FLASH_LORA,
14
+ SCHEDULER_CONFIG_MAP,
15
  scheduler_names,
16
+ IP_ADAPTER_MODELS,
17
  IP_ADAPTERS_SD,
18
  IP_ADAPTERS_SDXL,
19
+ REPO_IMAGE_ENCODER,
20
+ ALL_PROMPT_WEIGHT_OPTIONS,
21
+ SD15_TASKS,
22
+ SDXL_TASKS,
23
  )
24
  import time
25
  from PIL import ImageFile
26
+ #import urllib.parse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  ImageFile.LOAD_TRUNCATED_IMAGES = True
 
29
  print(os.getenv("SPACES_ZERO_GPU"))
30
 
 
31
  import gradio as gr
32
  import logging
33
  logging.getLogger("diffusers").setLevel(logging.ERROR)
 
38
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
39
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
40
  from stablepy import logger
41
+ logger.setLevel(logging.CRITICAL)
42
 
43
  from env import (
44
+ HF_TOKEN, hf_read_token, # to use only for private repos
45
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
46
  HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
47
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
48
+ directory_models, directory_loras, directory_vaes, directory_embeds,
49
+ directory_embeds_sdxl, directory_embeds_positive_sdxl,
50
+ load_diffusers_format_model, download_model_list, download_lora_list,
51
+ download_vae_list, download_embeds)
52
+
53
+ PREPROCESSOR_CONTROLNET = {
54
+ "openpose": [
55
+ "Openpose",
56
+ "None",
57
+ ],
58
+ "scribble": [
59
+ "HED",
60
+ "PidiNet",
61
+ "None",
62
+ ],
63
+ "softedge": [
64
+ "PidiNet",
65
+ "HED",
66
+ "HED safe",
67
+ "PidiNet safe",
68
+ "None",
69
+ ],
70
+ "segmentation": [
71
+ "UPerNet",
72
+ "None",
73
+ ],
74
+ "depth": [
75
+ "DPT",
76
+ "Midas",
77
+ "None",
78
+ ],
79
+ "normalbae": [
80
+ "NormalBae",
81
+ "None",
82
+ ],
83
+ "lineart": [
84
+ "Lineart",
85
+ "Lineart coarse",
86
+ "Lineart (anime)",
87
+ "None",
88
+ "None (anime)",
89
+ ],
90
+ "lineart_anime": [
91
+ "Lineart",
92
+ "Lineart coarse",
93
+ "Lineart (anime)",
94
+ "None",
95
+ "None (anime)",
96
+ ],
97
+ "shuffle": [
98
+ "ContentShuffle",
99
+ "None",
100
+ ],
101
+ "canny": [
102
+ "Canny",
103
+ "None",
104
+ ],
105
+ "mlsd": [
106
+ "MLSD",
107
+ "None",
108
+ ],
109
+ "ip2p": [
110
+ "ip2p"
111
+ ],
112
+ "recolor": [
113
+ "Recolor luminance",
114
+ "Recolor intensity",
115
+ "None",
116
+ ],
117
+ "tile": [
118
+ "Mild Blur",
119
+ "Moderate Blur",
120
+ "Heavy Blur",
121
+ "None",
122
+ ],
123
+ }
124
+
125
+ TASK_STABLEPY = {
126
+ 'txt2img': 'txt2img',
127
+ 'img2img': 'img2img',
128
+ 'inpaint': 'inpaint',
129
+ # 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
130
+ # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
131
+ # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
132
+ # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
133
+ # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
134
+ 'openpose ControlNet': 'openpose',
135
+ 'canny ControlNet': 'canny',
136
+ 'mlsd ControlNet': 'mlsd',
137
+ 'scribble ControlNet': 'scribble',
138
+ 'softedge ControlNet': 'softedge',
139
+ 'segmentation ControlNet': 'segmentation',
140
+ 'depth ControlNet': 'depth',
141
+ 'normalbae ControlNet': 'normalbae',
142
+ 'lineart ControlNet': 'lineart',
143
+ 'lineart_anime ControlNet': 'lineart_anime',
144
+ 'shuffle ControlNet': 'shuffle',
145
+ 'ip2p ControlNet': 'ip2p',
146
+ 'optical pattern ControlNet': 'pattern',
147
+ 'recolor ControlNet': 'recolor',
148
+ 'tile ControlNet': 'tile',
149
+ }
150
+
151
+ TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
152
+
153
+ UPSCALER_DICT_GUI = {
154
+ None: None,
155
+ "Lanczos": "Lanczos",
156
+ "Nearest": "Nearest",
157
+ 'Latent': 'Latent',
158
+ 'Latent (antialiased)': 'Latent (antialiased)',
159
+ 'Latent (bicubic)': 'Latent (bicubic)',
160
+ 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
161
+ 'Latent (nearest)': 'Latent (nearest)',
162
+ 'Latent (nearest-exact)': 'Latent (nearest-exact)',
163
+ "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
164
+ "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
165
+ "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
166
+ "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
167
+ "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
168
+ "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
169
+ "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
170
+ "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
171
+ "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
172
+ "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
173
+ "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
174
+ "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
175
+ "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
176
+ "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
177
+ }
178
+
179
+ UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
180
+
181
+
182
+ def get_model_list(directory_path):
183
+ model_list = []
184
+ valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
185
+
186
+ for filename in os.listdir(directory_path):
187
+ if os.path.splitext(filename)[1] in valid_extensions:
188
+ # name_without_extension = os.path.splitext(filename)[0]
189
+ file_path = os.path.join(directory_path, filename)
190
+ # model_list.append((name_without_extension, file_path))
191
+ model_list.append(file_path)
192
+ print('\033[34mFILE: ' + file_path + '\033[0m')
193
+ return model_list
194
 
195
+ ## BEGIN MOD
196
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
197
  get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)
198
 
199
  # - **Download Models**
200
+ download_model = ", ".join(download_model_list)
201
  # - **Download VAEs**
202
+ download_vae = ", ".join(download_vae_list)
203
  # - **Download LoRAs**
204
+ download_lora = ", ".join(download_lora_list)
205
 
206
+ #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
207
+ download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
208
 
209
+ load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
210
  ## END MOD
211
 
212
  # Download stuffs
213
  for url in [url.strip() for url in download_model.split(',')]:
214
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
215
+ download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
216
  for url in [url.strip() for url in download_vae.split(',')]:
217
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
218
+ download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
219
  for url in [url.strip() for url in download_lora.split(',')]:
220
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
221
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
222
 
223
  # Download Embeddings
224
+ for url_embed in download_embeds:
225
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
226
+ download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)
227
 
228
  # Build list models
229
+ embed_list = get_model_list(directory_embeds)
230
+ model_list = get_model_list(directory_models)
231
  model_list = load_diffusers_format_model + model_list
 
232
  ## BEGIN MOD
233
  lora_model_list = get_lora_model_list()
234
+ vae_model_list = get_model_list(directory_vaes)
235
  vae_model_list.insert(0, "None")
236
 
237
+ #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
238
+ #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
239
+ embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
240
 
241
  def get_embed_list(pipeline_name):
242
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
 
244
 
245
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
246
 
247
+ msg_inc_vae = (
248
+ "Use the right VAE for your model to maintain image quality. The wrong"
249
+ " VAE can lead to poor results, like blurriness in the generated images."
250
+ )
251
+
252
+ SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
253
+ SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
254
+ FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
255
+
256
+ MODEL_TYPE_TASK = {
257
+ "SD 1.5": SD_TASK,
258
+ "SDXL": SDXL_TASK,
259
+ "FLUX": FLUX_TASK,
260
+ }
261
+
262
+ MODEL_TYPE_CLASS = {
263
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
264
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
265
+ "diffusers:FluxPipeline": "FLUX",
266
+ }
267
+
268
+ POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
269
+
270
+ def extract_parameters(input_string):
271
+ parameters = {}
272
+ input_string = input_string.replace("\n", "")
273
+
274
+ if "Negative prompt:" not in input_string:
275
+ if "Steps:" in input_string:
276
+ input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
277
+ else:
278
+ print("Invalid metadata")
279
+ parameters["prompt"] = input_string
280
+ return parameters
281
+
282
+ parm = input_string.split("Negative prompt:")
283
+ parameters["prompt"] = parm[0].strip()
284
+ if "Steps:" not in parm[1]:
285
+ print("Steps not detected")
286
+ parameters["neg_prompt"] = parm[1].strip()
287
+ return parameters
288
+ parm = parm[1].split("Steps:")
289
+ parameters["neg_prompt"] = parm[0].strip()
290
+ input_string = "Steps:" + parm[1]
291
+
292
+ # Extracting Steps
293
+ steps_match = re.search(r'Steps: (\d+)', input_string)
294
+ if steps_match:
295
+ parameters['Steps'] = int(steps_match.group(1))
296
+
297
+ # Extracting Size
298
+ size_match = re.search(r'Size: (\d+x\d+)', input_string)
299
+ if size_match:
300
+ parameters['Size'] = size_match.group(1)
301
+ width, height = map(int, parameters['Size'].split('x'))
302
+ parameters['width'] = width
303
+ parameters['height'] = height
304
+
305
+ # Extracting other parameters
306
+ other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
307
+ for param in other_parameters:
308
+ parameters[param[0]] = param[1].strip('"')
309
+
310
+ return parameters
311
+
312
+ def get_model_type(repo_id: str):
313
+ api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
314
+ default = "SD 1.5"
315
+ try:
316
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
317
+ tags = model.tags
318
+ for tag in tags:
319
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
320
+ except Exception:
321
+ return default
322
+ return default
323
+
324
  ## BEGIN MOD
325
  class GuiSD:
326
+ def __init__(self):
327
  self.model = None
328
+
329
+ print("Loading model...")
330
+ self.model = Model_Diffusers(
331
+ base_model_id="Lykon/dreamshaper-8",
332
+ task_name="txt2img",
333
+ vae_model=None,
334
+ type_model_precision=torch.float16,
335
+ retain_task_model_in_cache=False,
336
+ device="cpu",
337
+ )
338
+ self.model.load_beta_styles()
339
+ #self.model.device = torch.device("cpu") #
340
 
341
  def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
342
  #progress(0, desc="Start inference...")
 
350
  return img
351
 
352
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
  #yield f"Loading model: {model_name}"
355
+
356
+ vae_model = vae_model if vae_model != "None" else None
357
+ model_type = get_model_type(model_name)
358
 
359
  if vae_model:
360
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
361
  if model_type != vae_type:
362
+ gr.Warning(msg_inc_vae)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ self.model.device = torch.device("cpu")
365
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
366
+
367
+ self.model.load_pipe(
368
+ model_name,
369
+ task_name=TASK_STABLEPY[task],
370
+ vae_model=vae_model if vae_model != "None" else None,
371
+ type_model_precision=dtype_model,
372
+ retain_task_model_in_cache=False,
373
+ )
374
  #yield f"Model loaded: {model_name}"
375
 
376
  #@spaces.GPU
377
+ @torch.inference_mode()
378
  def generate_pipeline(
379
  self,
380
  prompt,
 
479
  mode_ip2,
480
  scale_ip2,
481
  pag_scale,
482
+ #progress=gr.Progress(track_tqdm=True),
483
  ):
484
+ #progress(0, desc="Preparing inference...")
485
+
 
486
  vae_model = vae_model if vae_model != "None" else None
487
  loras_list = [lora1, lora2, lora3, lora4, lora5]
488
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
489
  msg_lora = ""
490
 
491
+ print("Config model:", model_name, vae_model, loras_list)
492
+
493
  ## BEGIN MOD
 
494
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
495
  global lora_model_list
496
  lora_model_list = get_lora_model_list()
497
  ## END MOD
498
 
 
 
499
  task = TASK_STABLEPY[task]
500
 
501
  params_ip_img = []
 
518
  params_ip_mode.append(modeip)
519
  params_ip_scale.append(scaleip)
520
 
 
 
 
521
  if task != "txt2img" and not image_control:
522
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
523
 
 
589
  "high_threshold": high_threshold,
590
  "value_threshold": value_threshold,
591
  "distance_threshold": distance_threshold,
592
+ "lora_A": lora1 if lora1 != "None" and lora1 != "" else None,
593
  "lora_scale_A": lora_scale1,
594
+ "lora_B": lora2 if lora2 != "None" and lora2 != "" else None,
595
  "lora_scale_B": lora_scale2,
596
+ "lora_C": lora3 if lora3 != "None" and lora3 != "" else None,
597
  "lora_scale_C": lora_scale3,
598
+ "lora_D": lora4 if lora4 != "None" and lora4 != "" else None,
599
  "lora_scale_D": lora_scale4,
600
+ "lora_E": lora5 if lora5 != "None" and lora5 != "" else None,
601
  "lora_scale_E": lora_scale5,
602
  ## BEGIN MOD
603
  "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
 
647
  }
648
 
649
  self.model.device = torch.device("cuda:0")
650
+ if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 and loras_list != [""] * 5:
651
  self.model.pipe.transformer.to(self.model.device)
652
  print("transformer to cuda")
653
 
654
+ #progress(1, desc="Inference preparation completed. Starting inference...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
+ info_state = "" # for yield version
657
+ return self.infer_short(self.model, pipe_params), info_state
658
  ## END MOD
659
 
 
 
 
 
 
660
  def dynamic_gpu_duration(func, duration, *args):
661
 
 
662
  @spaces.GPU(duration=duration)
663
  def wrapped_func():
664
  return func(*args)
 
678
  load_lora_cpu = args[-3]
679
  generation_args = args[:-3]
680
  lora_list = [
681
+ None if item == "None" or item == "" else item
682
  for item in [args[7], args[9], args[11], args[13], args[15]]
683
  ]
684
  lora_status = [None] * 5
 
688
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
689
 
690
  #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
691
+ # yield None, msg_load_lora
692
 
693
  # Load lora in CPU
694
  if load_lora_cpu:
 
714
  )
715
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
716
 
717
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
 
718
  gr.Info(msg_request)
719
  print(msg_request)
720
+
721
+ # yield from sd_gen.generate_pipeline(*generation_args)
722
 
723
  start_time = time.time()
724
 
 
 
725
  return dynamic_gpu_duration(
726
  sd_gen.generate_pipeline,
727
  gpu_duration_arg,
 
729
  )
730
 
731
  end_time = time.time()
 
 
 
 
732
 
733
  if verbose_arg:
734
+ execution_time = end_time - start_time
735
+ msg_task_complete = (
736
+ f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
737
+ )
738
  gr.Info(msg_task_complete)
739
  print(msg_task_complete)
740
 
741
+ def extract_exif_data(image):
742
+ if image is None: return ""
743
+
744
+ try:
745
+ metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
746
+
747
+ for key in metadata_keys:
748
+ if key in image.info:
749
+ return image.info[key]
750
 
751
+ return str(image.info)
752
 
753
+ except Exception as e:
754
+ return f"Error extracting metadata: {str(e)}"
755
+
756
+ @spaces.GPU(duration=20)
757
  def esrgan_upscale(image, upscaler_name, upscaler_size):
758
  if image is None: return None
759
 
 
775
 
776
  return image_path
777
 
 
778
  dynamic_gpu_duration.zerogpu = True
779
  sd_gen_generate_pipeline.zerogpu = True
 
 
780
 
781
  from pathlib import Path
782
  from PIL import Image
783
  import random, json
784
  from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
785
  get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
786
+ get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD,
787
+ normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en)
 
788
 
789
+ sd_gen = GuiSD()
790
  #@spaces.GPU
791
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
792
  model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
 
801
  gpu_duration = 59
802
 
803
  images: list[tuple[PIL.Image.Image, str | None]] = []
804
+ info: str = ""
805
  progress(0, desc="Preparing...")
806
 
807
  if randomize_seed:
 
828
  sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0])
829
  progress(1, desc="Model loaded.")
830
  progress(0, desc="Starting Inference...")
831
+ images, info = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
832
  guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
833
  lora4, lora4_wt, lora5, lora5_wt, sampler,
834
  height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
 
1008
  def download_lora(dl_urls: str):
1009
  global loras_url_to_path_dict
1010
  dl_path = ""
1011
+ before = get_local_model_list(directory_loras)
1012
  urls = []
1013
  for url in [url.strip() for url in dl_urls.split(',')]:
1014
+ local_path = f"{directory_loras}/{url.split('/')[-1]}"
1015
  if not Path(local_path).exists():
1016
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
1017
  urls.append(url)
1018
+ after = get_local_model_list(directory_loras)
1019
  new_files = list_sub(after, before)
1020
  i = 0
1021
  for file in new_files: