John6666 committed on
Commit 42ec22d · verified · 1 Parent(s): 0f73bfb

Upload dc.py

Files changed (1)
  1. dc.py +214 -330
dc.py CHANGED
@@ -1,33 +1,52 @@
  import spaces
  import os
  from stablepy import Model_Diffusers
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
- from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
  import torch
  import re
- from huggingface_hub import HfApi
  from stablepy import (
-     CONTROLNET_MODEL_IDS,
-     VALID_TASKS,
-     T2I_PREPROCESSOR_NAME,
-     FLASH_LORA,
-     SCHEDULER_CONFIG_MAP,
      scheduler_names,
-     IP_ADAPTER_MODELS,
      IP_ADAPTERS_SD,
      IP_ADAPTERS_SDXL,
-     REPO_IMAGE_ENCODER,
-     ALL_PROMPT_WEIGHT_OPTIONS,
-     SD15_TASKS,
-     SDXL_TASKS,
  )
  import time
  from PIL import ImageFile
- #import urllib.parse

  ImageFile.LOAD_TRUNCATED_IMAGES = True
  print(os.getenv("SPACES_ZERO_GPU"))

  import gradio as gr
  import logging
  logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -38,205 +57,63 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
  from stablepy import logger
- logger.setLevel(logging.CRITICAL)

  from env import (
-     HF_TOKEN, hf_read_token,  # use only for private repos
      CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
      HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
      HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
-     directory_models, directory_loras, directory_vaes, directory_embeds,
-     directory_embeds_sdxl, directory_embeds_positive_sdxl,
-     load_diffusers_format_model, download_model_list, download_lora_list,
-     download_vae_list, download_embeds)
-
- PREPROCESSOR_CONTROLNET = {
-     "openpose": [
-         "Openpose",
-         "None",
-     ],
-     "scribble": [
-         "HED",
-         "PidiNet",
-         "None",
-     ],
-     "softedge": [
-         "PidiNet",
-         "HED",
-         "HED safe",
-         "PidiNet safe",
-         "None",
-     ],
-     "segmentation": [
-         "UPerNet",
-         "None",
-     ],
-     "depth": [
-         "DPT",
-         "Midas",
-         "None",
-     ],
-     "normalbae": [
-         "NormalBae",
-         "None",
-     ],
-     "lineart": [
-         "Lineart",
-         "Lineart coarse",
-         "Lineart (anime)",
-         "None",
-         "None (anime)",
-     ],
-     "lineart_anime": [
-         "Lineart",
-         "Lineart coarse",
-         "Lineart (anime)",
-         "None",
-         "None (anime)",
-     ],
-     "shuffle": [
-         "ContentShuffle",
-         "None",
-     ],
-     "canny": [
-         "Canny",
-         "None",
-     ],
-     "mlsd": [
-         "MLSD",
-         "None",
-     ],
-     "ip2p": [
-         "ip2p"
-     ],
-     "recolor": [
-         "Recolor luminance",
-         "Recolor intensity",
-         "None",
-     ],
-     "tile": [
-         "Mild Blur",
-         "Moderate Blur",
-         "Heavy Blur",
-         "None",
-     ],
- }
-
- TASK_STABLEPY = {
-     'txt2img': 'txt2img',
-     'img2img': 'img2img',
-     'inpaint': 'inpaint',
-     # 'canny T2I Adapter': 'sdxl_canny_t2i',  # T2I adapters have no step-callback parameter, so they don't work with diffusers 0.29.0
-     # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
-     # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
-     # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
-     # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
-     'openpose ControlNet': 'openpose',
-     'canny ControlNet': 'canny',
-     'mlsd ControlNet': 'mlsd',
-     'scribble ControlNet': 'scribble',
-     'softedge ControlNet': 'softedge',
-     'segmentation ControlNet': 'segmentation',
-     'depth ControlNet': 'depth',
-     'normalbae ControlNet': 'normalbae',
-     'lineart ControlNet': 'lineart',
-     'lineart_anime ControlNet': 'lineart_anime',
-     'shuffle ControlNet': 'shuffle',
-     'ip2p ControlNet': 'ip2p',
-     'optical pattern ControlNet': 'pattern',
-     'recolor ControlNet': 'recolor',
-     'tile ControlNet': 'tile',
- }
-
- TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
-
- UPSCALER_DICT_GUI = {
-     None: None,
-     "Lanczos": "Lanczos",
-     "Nearest": "Nearest",
-     'Latent': 'Latent',
-     'Latent (antialiased)': 'Latent (antialiased)',
-     'Latent (bicubic)': 'Latent (bicubic)',
-     'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
-     'Latent (nearest)': 'Latent (nearest)',
-     'Latent (nearest-exact)': 'Latent (nearest-exact)',
-     "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-     "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
-     "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-     "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-     "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
-     "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
-     "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
-     "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
-     "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
-     "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
-     "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
-     "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
-     "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
-     "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
- }
-
- UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
-
-
- def get_model_list(directory_path):
-     model_list = []
-     valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
-
-     for filename in os.listdir(directory_path):
-         if os.path.splitext(filename)[1] in valid_extensions:
-             # name_without_extension = os.path.splitext(filename)[0]
-             file_path = os.path.join(directory_path, filename)
-             # model_list.append((name_without_extension, file_path))
-             model_list.append(file_path)
-             print('\033[34mFILE: ' + file_path + '\033[0m')
-     return model_list

- ## BEGIN MOD
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
      get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)

  # - **Download Models**
- download_model = ", ".join(download_model_list)
  # - **Download VAEs**
- download_vae = ", ".join(download_vae_list)
  # - **Download LoRAs**
- download_lora = ", ".join(download_lora_list)

- #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
- download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)

- load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
  ## END MOD

  # Download stuff
  for url in [url.strip() for url in download_model.split(',')]:
      if not os.path.exists(f"./models/{url.split('/')[-1]}"):
-         download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
  for url in [url.strip() for url in download_vae.split(',')]:
      if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
-         download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
  for url in [url.strip() for url in download_lora.split(',')]:
      if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
-         download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)

  # Download Embeddings
- for url_embed in download_embeds:
      if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
-         download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)

  # Build the model lists
- embed_list = get_model_list(directory_embeds)
- model_list = get_model_list(directory_models)
  model_list = load_diffusers_format_model + model_list

  ## BEGIN MOD
  lora_model_list = get_lora_model_list()
- vae_model_list = get_model_list(directory_vaes)
  vae_model_list.insert(0, "None")

- #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
- #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
- embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)

  def get_embed_list(pipeline_name):
      return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
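
Note: all three download loops above share one guard — split the comma-joined URL string, derive the target filename from the URL tail, and skip the download when the file is already on disk, so Space restarts don't re-fetch anything. A minimal standalone sketch of that pattern (fetch_file is a hypothetical stand-in for download_things, which additionally handles HF and CivitAI auth tokens):

import os
import urllib.request

def fetch_file(directory: str, url: str) -> str:
    # Hypothetical stand-in for download_things(): fetch `url` into `directory`.
    os.makedirs(directory, exist_ok=True)
    dest = os.path.join(directory, url.split('/')[-1])
    urllib.request.urlretrieve(url, dest)
    return dest

download_model = "https://example.com/a.safetensors, https://example.com/b.safetensors"
for url in [u.strip() for u in download_model.split(',')]:
    if not os.path.exists(f"./models/{url.split('/')[-1]}"):  # skip files already on disk
        fetch_file("./models", url)
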
@@ -244,99 +121,13 @@ def get_embed_list(pipeline_name):

  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')

- msg_inc_vae = (
-     "Use the right VAE for your model to maintain image quality. The wrong"
-     " VAE can lead to poor results, like blurriness in the generated images."
- )
-
- SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
- SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
- FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
-
- MODEL_TYPE_TASK = {
-     "SD 1.5": SD_TASK,
-     "SDXL": SDXL_TASK,
-     "FLUX": FLUX_TASK,
- }
-
- MODEL_TYPE_CLASS = {
-     "diffusers:StableDiffusionPipeline": "SD 1.5",
-     "diffusers:StableDiffusionXLPipeline": "SDXL",
-     "diffusers:FluxPipeline": "FLUX",
- }
-
- POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
-
- def extract_parameters(input_string):
-     parameters = {}
-     input_string = input_string.replace("\n", "")
-
-     if "Negative prompt:" not in input_string:
-         if "Steps:" in input_string:
-             input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
-         else:
-             print("Invalid metadata")
-             parameters["prompt"] = input_string
-             return parameters
-
-     parm = input_string.split("Negative prompt:")
-     parameters["prompt"] = parm[0].strip()
-     if "Steps:" not in parm[1]:
-         print("Steps not detected")
-         parameters["neg_prompt"] = parm[1].strip()
-         return parameters
-     parm = parm[1].split("Steps:")
-     parameters["neg_prompt"] = parm[0].strip()
-     input_string = "Steps:" + parm[1]
-
-     # Extracting Steps
-     steps_match = re.search(r'Steps: (\d+)', input_string)
-     if steps_match:
-         parameters['Steps'] = int(steps_match.group(1))
-
-     # Extracting Size
-     size_match = re.search(r'Size: (\d+x\d+)', input_string)
-     if size_match:
-         parameters['Size'] = size_match.group(1)
-         width, height = map(int, parameters['Size'].split('x'))
-         parameters['width'] = width
-         parameters['height'] = height
-
-     # Extracting other parameters
-     other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
-     for param in other_parameters:
-         parameters[param[0]] = param[1].strip('"')
-
-     return parameters
-
- def get_model_type(repo_id: str):
-     api = HfApi(token=os.environ.get("HF_TOKEN"))  # token is needed for private or gated models
-     default = "SD 1.5"
-     try:
-         model = api.model_info(repo_id=repo_id, timeout=5.0)
-         tags = model.tags
-         for tag in tags:
-             if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
-     except Exception:
-         return default
-     return default
-
  ## BEGIN MOD
  class GuiSD:
-     def __init__(self):
          self.model = None
-
-         print("Loading model...")
-         self.model = Model_Diffusers(
-             base_model_id="Lykon/dreamshaper-8",
-             task_name="txt2img",
-             vae_model=None,
-             type_model_precision=torch.float16,
-             retain_task_model_in_cache=False,
-             device="cpu",
-         )
-         self.model.load_beta_styles()
-         #self.model.device = torch.device("cpu") #

      def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
          #progress(0, desc="Start inference...")
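
Note: extract_parameters() above (moved to utils.py by this commit) parses A1111-style generation metadata by slicing the string at "Negative prompt:" and "Steps:", then regex-scanning the remainder. A worked example with hypothetical values, reproducing the same steps:

import re

metadata = ("1girl, masterpiece\n"
            "Negative prompt: lowres, bad anatomy\n"
            "Steps: 28, Sampler: Euler a, CFG scale: 7, Seed: 12345, Size: 832x1216")

s = metadata.replace("\n", "")
prompt, rest = s.split("Negative prompt:")
neg_prompt, rest = rest.split("Steps:")
rest = "Steps:" + rest
steps = int(re.search(r'Steps: (\d+)', rest).group(1))  # 28
width, height = map(int, re.search(r'Size: (\d+x\d+)', rest).group(1).split('x'))  # 832, 1216
# Multi-word keys are captured by their last word only: "CFG scale" -> "scale".
others = dict(re.findall(r'(\w+): (.*?)(?=, \w+|$)', rest))
print(prompt.strip(), "|", neg_prompt.strip(), "|", steps, width, height, others)
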
@@ -350,31 +141,86 @@ class GuiSD:
          return img

      def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
-
-         #yield f"Loading model: {model_name}"
-
          vae_model = vae_model if vae_model != "None" else None
          model_type = get_model_type(model_name)

          if vae_model:
              vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
              if model_type != vae_type:
-                 gr.Warning(msg_inc_vae)

-         self.model.device = torch.device("cpu")
-         dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
-
-         self.model.load_pipe(
-             model_name,
-             task_name=TASK_STABLEPY[task],
-             vae_model=vae_model if vae_model != "None" else None,
-             type_model_precision=dtype_model,
-             retain_task_model_in_cache=False,
-         )
          #yield f"Model loaded: {model_name}"

      #@spaces.GPU
-     @torch.inference_mode()
      def generate_pipeline(
          self,
          prompt,
@@ -479,23 +325,24 @@
          mode_ip2,
          scale_ip2,
          pag_scale,
-         #progress=gr.Progress(track_tqdm=True),
      ):
-         #progress(0, desc="Preparing inference...")
-
          vae_model = vae_model if vae_model != "None" else None
          loras_list = [lora1, lora2, lora3, lora4, lora5]
          vae_msg = f"VAE: {vae_model}" if vae_model else ""
          msg_lora = ""

-         print("Config model:", model_name, vae_model, loras_list)
-
          ## BEGIN MOD
          prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
          global lora_model_list
          lora_model_list = get_lora_model_list()
          ## END MOD

          task = TASK_STABLEPY[task]

          params_ip_img = []
@@ -518,6 +365,9 @@
              params_ip_mode.append(modeip)
              params_ip_scale.append(scaleip)

          if task != "txt2img" and not image_control:
              raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
@@ -589,15 +439,15 @@
              "high_threshold": high_threshold,
              "value_threshold": value_threshold,
              "distance_threshold": distance_threshold,
-             "lora_A": lora1 if lora1 != "None" and lora1 != "" else None,
              "lora_scale_A": lora_scale1,
-             "lora_B": lora2 if lora2 != "None" and lora2 != "" else None,
              "lora_scale_B": lora_scale2,
-             "lora_C": lora3 if lora3 != "None" and lora3 != "" else None,
              "lora_scale_C": lora_scale3,
-             "lora_D": lora4 if lora4 != "None" and lora4 != "" else None,
              "lora_scale_D": lora_scale4,
-             "lora_E": lora5 if lora5 != "None" and lora5 != "" else None,
              "lora_scale_E": lora_scale5,
              ## BEGIN MOD
              "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
@@ -647,18 +497,59 @@
          }

          self.model.device = torch.device("cuda:0")
-         if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 and loras_list != [""] * 5:
              self.model.pipe.transformer.to(self.model.device)
              print("transformer to cuda")

-         #progress(1, desc="Inference preparation completed. Starting inference...")

-         info_state = ""  # for yield version
-         return self.infer_short(self.model, pipe_params), info_state
          ## END MOD

  def dynamic_gpu_duration(func, duration, *args):

      @spaces.GPU(duration=duration)
      def wrapped_func():
          return func(*args)
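
Note: on ZeroGPU Spaces, @spaces.GPU is normally applied at import time with a fixed duration; dynamic_gpu_duration() instead re-decorates an inner closure on every call, so the GPU window can be sized per request. A sketch of the complete pattern, assuming the `spaces` package of Hugging Face ZeroGPU (the tail of the function is elided in this hunk; it presumably invokes and returns wrapped_func()):

import spaces  # available inside Hugging Face ZeroGPU Spaces

def dynamic_gpu_duration(func, duration, *args):
    @spaces.GPU(duration=duration)  # request `duration` seconds of GPU time
    def wrapped_func():
        return func(*args)

    return wrapped_func()
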
@@ -678,7 +569,7 @@ def sd_gen_generate_pipeline(*args):
      load_lora_cpu = args[-3]
      generation_args = args[:-3]
      lora_list = [
-         None if item == "None" or item == "" else item
          for item in [args[7], args[9], args[11], args[13], args[15]]
      ]
      lora_status = [None] * 5
@@ -688,7 +579,7 @@ def sd_gen_generate_pipeline(*args):
          msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."

      #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
-     #    yield None, msg_load_lora

      # Load lora in CPU
      if load_lora_cpu:
@@ -714,14 +605,16 @@ def sd_gen_generate_pipeline(*args):
      )
      gr.Info(f"LoRAs in cache: {lora_cache_msg}")

-     msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
      gr.Info(msg_request)
      print(msg_request)
-
-     # yield from sd_gen.generate_pipeline(*generation_args)

      start_time = time.time()

      return dynamic_gpu_duration(
          sd_gen.generate_pipeline,
          gpu_duration_arg,
@@ -729,31 +622,19 @@ def sd_gen_generate_pipeline(*args):
      )

      end_time = time.time()

      if verbose_arg:
-         execution_time = end_time - start_time
-         msg_task_complete = (
-             f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
-         )
          gr.Info(msg_task_complete)
          print(msg_task_complete)

- def extract_exif_data(image):
-     if image is None: return ""
-
-     try:
-         metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
-
-         for key in metadata_keys:
-             if key in image.info:
-                 return image.info[key]
-
-         return str(image.info)
-
-     except Exception as e:
-         return f"Error extracting metadata: {str(e)}"
-
- @spaces.GPU(duration=20)
  def esrgan_upscale(image, upscaler_name, upscaler_size):
      if image is None: return None

@@ -775,18 +656,21 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):

      return image_path

  dynamic_gpu_duration.zerogpu = True
  sd_gen_generate_pipeline.zerogpu = True

  from pathlib import Path
  from PIL import Image
  import random, json
  from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
      get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
-     get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD,
-     normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en)

- sd_gen = GuiSD()
  #@spaces.GPU
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
          model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
@@ -801,7 +685,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
      gpu_duration = 59

      images: list[tuple[PIL.Image.Image, str | None]] = []
-     info: str = ""
      progress(0, desc="Preparing...")

      if randomize_seed:
@@ -828,7 +712,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
      sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0])
      progress(1, desc="Model loaded.")
      progress(0, desc="Starting Inference...")
-     images, info = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
          guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
          lora4, lora4_wt, lora5, lora5_wt, sampler,
          height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
@@ -1008,14 +892,14 @@ def update_lora_dict(path: str):
  def download_lora(dl_urls: str):
      global loras_url_to_path_dict
      dl_path = ""
-     before = get_local_model_list(directory_loras)
      urls = []
      for url in [url.strip() for url in dl_urls.split(',')]:
-         local_path = f"{directory_loras}/{url.split('/')[-1]}"
          if not Path(local_path).exists():
-             download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
              urls.append(url)
-     after = get_local_model_list(directory_loras)
      new_files = list_sub(after, before)
      i = 0
      for file in new_files:

The same hunks in the new version of dc.py (additions marked with +):

@@ -1,33 +1,52 @@
  import spaces
  import os
  from stablepy import Model_Diffusers
+ from constants import (
+     PREPROCESSOR_CONTROLNET,
+     TASK_STABLEPY,
+     TASK_MODEL_LIST,
+     UPSCALER_DICT_GUI,
+     UPSCALER_KEYS,
+     PROMPT_W_OPTIONS,
+     WARNING_MSG_VAE,
+     SDXL_TASK,
+     MODEL_TYPE_TASK,
+     POST_PROCESSING_SAMPLER,
+
+ )
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
  import torch
  import re
  from stablepy import (
      scheduler_names,
      IP_ADAPTERS_SD,
      IP_ADAPTERS_SDXL,
  )
  import time
  from PIL import ImageFile
+ from utils import (
+     get_model_list,
+     extract_parameters,
+     get_model_type,
+     extract_exif_data,
+     create_mask_now,
+     download_diffuser_repo,
+     progress_step_bar,
+     html_template_message,
+ )
+ from datetime import datetime
+ import gradio as gr
+ import logging
+ import diffusers
+ import warnings
+ from stablepy import logger
+ # import urllib.parse

  ImageFile.LOAD_TRUNCATED_IMAGES = True
+ # os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
  print(os.getenv("SPACES_ZERO_GPU"))

+ ## BEGIN MOD
  import gradio as gr
  import logging
  logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -38,205 +57,63 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
  from stablepy import logger
+ logger.setLevel(logging.DEBUG)

  from env import (
+     HF_TOKEN, HF_READ_TOKEN,  # use only for private repos
      CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
      HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
      HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
+     DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS,
+     DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL,
+     LOAD_DIFFUSERS_FORMAT_MODEL, DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST,
+     DOWNLOAD_VAE_LIST, DOWNLOAD_EMBEDS)

  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
      get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)

  # - **Download Models**
+ download_model = ", ".join(DOWNLOAD_MODEL_LIST)
  # - **Download VAEs**
+ download_vae = ", ".join(DOWNLOAD_VAE_LIST)
  # - **Download LoRAs**
+ download_lora = ", ".join(DOWNLOAD_LORA_LIST)

+ #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
+ download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)

+ load_diffusers_format_model = list_uniq(LOAD_DIFFUSERS_FORMAT_MODEL + get_model_id_list())
  ## END MOD

  # Download stuff
  for url in [url.strip() for url in download_model.split(',')]:
      if not os.path.exists(f"./models/{url.split('/')[-1]}"):
+         download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
  for url in [url.strip() for url in download_vae.split(',')]:
      if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
+         download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
  for url in [url.strip() for url in download_lora.split(',')]:
      if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
+         download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)

  # Download Embeddings
+ for url_embed in DOWNLOAD_EMBEDS:
      if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
+         download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)

  # Build the model lists
+ embed_list = get_model_list(DIRECTORY_EMBEDS)
+ model_list = get_model_list(DIRECTORY_MODELS)
  model_list = load_diffusers_format_model + model_list
+
  ## BEGIN MOD
  lora_model_list = get_lora_model_list()
+ vae_model_list = get_model_list(DIRECTORY_VAES)
  vae_model_list.insert(0, "None")

+ #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
+ #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
+ embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)

  def get_embed_list(pipeline_name):
      return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
@@ -244,99 +121,13 @@ def get_embed_list(pipeline_name):

  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')

  ## BEGIN MOD
  class GuiSD:
+     def __init__(self, stream=True):
          self.model = None
+         self.status_loading = False
+         self.sleep_loading = 4
+         self.last_load = datetime.now()

      def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
          #progress(0, desc="Start inference...")
@@ -350,31 +141,86 @@ class GuiSD:
          return img

      def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
          vae_model = vae_model if vae_model != "None" else None
          model_type = get_model_type(model_name)
+         dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
+
+         if not os.path.exists(model_name):
+             _ = download_diffuser_repo(
+                 repo_name=model_name,
+                 model_type=model_type,
+                 revision="main",
+                 token=True,
+             )
+
+         '''for i in range(68):
+             if not self.status_loading:
+                 self.status_loading = True
+                 if i > 0:
+                     time.sleep(self.sleep_loading)
+                     print("Previous model ops...")
+                 break
+             time.sleep(0.5)
+             print(f"Waiting queue {i}")
+             yield "Waiting queue"
+
+         self.status_loading = True
+
+         yield f"Loading model: {model_name}"'''

          if vae_model:
              vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
              if model_type != vae_type:
+                 gr.Warning(WARNING_MSG_VAE)
+
+         print("Loading model...")
+
+         try:
+             start_time = time.time()
+
+             if self.model is None:
+                 self.model = Model_Diffusers(
+                     base_model_id=model_name,
+                     task_name=TASK_STABLEPY[task],
+                     vae_model=vae_model,
+                     type_model_precision=dtype_model,
+                     retain_task_model_in_cache=False,
+                     device="cpu",
+                 )
+             else:
+                 if self.model.base_model_id != model_name:
+                     load_now_time = datetime.now()
+                     elapsed_time = (load_now_time - self.last_load).total_seconds()
+
+                     if elapsed_time <= 8:
+                         print("Waiting for the previous model's time ops...")
+                         time.sleep(8 - elapsed_time)
+
+                 self.model.device = torch.device("cpu")
+                 self.model.load_pipe(
+                     model_name,
+                     task_name=TASK_STABLEPY[task],
+                     vae_model=vae_model,
+                     type_model_precision=dtype_model,
+                     retain_task_model_in_cache=False,
+                 )
+
+             end_time = time.time()
+             self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
+         except Exception as e:
+             self.last_load = datetime.now()
+             self.status_loading = False
+             self.sleep_loading = 4
+             raise e
+
+         self.last_load = datetime.now()
+         self.status_loading = False

          #yield f"Model loaded: {model_name}"
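
Note: load_new_model() now creates Model_Diffusers lazily on first use, reuses load_pipe() for later switches, and spaces out back-to-back reloads: if fewer than 8 seconds have passed since the previous load, it sleeps off the remainder (the measured load time also feeds sleep_loading, clamped to 4-10 s). A minimal sketch of that cooldown, using a hypothetical Loader class:

from datetime import datetime
import time

class Loader:
    def __init__(self, cooldown_s: float = 8.0):
        self.cooldown_s = cooldown_s
        self.last_load = datetime.now()

    def load(self, name: str):
        # Wait out the rest of the cooldown window before switching models,
        # in case teardown of the previous model is still in flight.
        elapsed = (datetime.now() - self.last_load).total_seconds()
        if elapsed <= self.cooldown_s:
            time.sleep(self.cooldown_s - elapsed)
        print(f"Loading {name}...")
        self.last_load = datetime.now()
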
      #@spaces.GPU
+     #@torch.inference_mode()
      def generate_pipeline(
          self,
          prompt,
@@ -479,23 +325,24 @@
          mode_ip2,
          scale_ip2,
          pag_scale,
      ):
+         info_state = html_template_message("Navigating latent space...")
+         #yield info_state, gr.update(), gr.update()
+
          vae_model = vae_model if vae_model != "None" else None
          loras_list = [lora1, lora2, lora3, lora4, lora5]
          vae_msg = f"VAE: {vae_model}" if vae_model else ""
          msg_lora = ""

          ## BEGIN MOD
+         loras_list = [s if s else "None" for s in loras_list]
          prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
          global lora_model_list
          lora_model_list = get_lora_model_list()
          ## END MOD

+         print("Config model:", model_name, vae_model, loras_list)
+
          task = TASK_STABLEPY[task]

          params_ip_img = []
@@ -518,6 +365,9 @@
              params_ip_mode.append(modeip)
              params_ip_scale.append(scaleip)

+         concurrency = 5
+         self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
+
          if task != "txt2img" and not image_control:
              raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
@@ -589,15 +439,15 @@
              "high_threshold": high_threshold,
              "value_threshold": value_threshold,
              "distance_threshold": distance_threshold,
+             "lora_A": lora1 if lora1 != "None" else None,
              "lora_scale_A": lora_scale1,
+             "lora_B": lora2 if lora2 != "None" else None,
              "lora_scale_B": lora_scale2,
+             "lora_C": lora3 if lora3 != "None" else None,
              "lora_scale_C": lora_scale3,
+             "lora_D": lora4 if lora4 != "None" else None,
              "lora_scale_D": lora_scale4,
+             "lora_E": lora5 if lora5 != "None" else None,
              "lora_scale_E": lora_scale5,
              ## BEGIN MOD
              "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
@@ -647,18 +497,59 @@
          }

          self.model.device = torch.device("cuda:0")
+         if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
              self.model.pipe.transformer.to(self.model.device)
              print("transformer to cuda")

+         #return self.infer_short(self.model, pipe_params), info_state
+
+         actual_progress = 0
+         info_images = gr.update()
+         for img, seed, image_path, metadata in self.model(**pipe_params):
+             info_state = progress_step_bar(actual_progress, steps)
+             actual_progress += concurrency
+             if image_path:
+                 info_images = f"Seeds: {str(seed)}"
+                 if vae_msg:
+                     info_images = info_images + "<br>" + vae_msg
+
+                 if "Cannot copy out of meta tensor; no data!" in self.model.last_lora_error:
+                     msg_ram = "Unable to process the LoRAs due to high RAM usage; please try again later."
+                     print(msg_ram)
+                     msg_lora += f"<br>{msg_ram}"
+
+                 for status, lora in zip(self.model.lora_status, self.model.lora_memory):
+                     if status:
+                         msg_lora += f"<br>Loaded: {lora}"
+                     elif status is not None:
+                         msg_lora += f"<br>Error with: {lora}"
+
+                 if msg_lora:
+                     info_images += msg_lora
+
+                 info_images = info_images + "<br>" + "GENERATION DATA:<br>" + metadata[0].replace("\n", "<br>") + "<br>-------<br>"
+
+                 download_links = "<br>".join(
+                     [
+                         f'<a href="{path.replace("/images/", "/file=/home/user/app/images/")}" download="{os.path.basename(path)}">Download Image {i + 1}</a>'
+                         for i, path in enumerate(image_path)
+                     ]
+                 )
+                 if save_generated_images:
+                     info_images += f"<br>{download_links}"

+         ## BEGIN MOD
+         img = save_images(img, metadata)
          ## END MOD

+         info_state = "COMPLETE"
+
+         #yield info_state, img, info_images
+         return info_state, img, info_images
+
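
Note: generate_pipeline() now drives stablepy in streaming mode — stream_config() makes the model yield a preview every `concurrency` steps, and each yield is turned into a progress message via progress_step_bar() (a helper from this Space's utils). A sketch of the consumption pattern, with a dummy generator standing in for self.model(**pipe_params):

def fake_pipeline(steps: int, concurrency: int):
    # Yields (image, seed, image_paths, metadata) every `concurrency` steps.
    for done in range(concurrency, steps + 1, concurrency):
        yield f"preview@{done}", 12345, [], ["metadata"]

steps, concurrency = 28, 5
actual_progress = 0
for img, seed, image_path, metadata in fake_pipeline(steps, concurrency):
    percent = min(100, int(100 * actual_progress / steps))
    print(f"{percent}% -> {img}")  # the app renders this as an HTML progress bar
    actual_progress += concurrency
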
  def dynamic_gpu_duration(func, duration, *args):

+     @torch.inference_mode()
      @spaces.GPU(duration=duration)
      def wrapped_func():
          return func(*args)
 
@@ -678,7 +569,7 @@ def sd_gen_generate_pipeline(*args):
      load_lora_cpu = args[-3]
      generation_args = args[:-3]
      lora_list = [
+         None if item == "None" or item == "" else item  # MOD
          for item in [args[7], args[9], args[11], args[13], args[15]]
      ]
      lora_status = [None] * 5
@@ -688,7 +579,7 @@ def sd_gen_generate_pipeline(*args):
          msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."

      #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
+     #    yield msg_load_lora, gr.update(), gr.update()

      # Load lora in CPU
      if load_lora_cpu:
@@ -714,14 +605,16 @@ def sd_gen_generate_pipeline(*args):
      )
      gr.Info(f"LoRAs in cache: {lora_cache_msg}")

+     msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
+     if verbose_arg:
          gr.Info(msg_request)
          print(msg_request)
+     #yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()

      start_time = time.time()

+     # yield from sd_gen.generate_pipeline(*generation_args)
+     #yield from dynamic_gpu_duration(
      return dynamic_gpu_duration(
          sd_gen.generate_pipeline,
          gpu_duration_arg,
@@ -729,31 +622,19 @@ def sd_gen_generate_pipeline(*args):
      )

      end_time = time.time()
+     execution_time = end_time - start_time
+     msg_task_complete = (
+         f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
+     )

      if verbose_arg:
          gr.Info(msg_task_complete)
          print(msg_task_complete)

+     yield msg_task_complete, gr.update(), gr.update()

+ @spaces.GPU(duration=15)
  def esrgan_upscale(image, upscaler_name, upscaler_size):
      if image is None: return None

@@ -775,18 +656,21 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):

      return image_path
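
Note: UPSCALER_DICT_GUI mixes two kinds of entries — plain names like "Lanczos"/"Nearest" (classical resampling) and URLs to ESRGAN .pth checkpoints, which esrgan_upscale() runs on the GPU. A sketch of the classical path only (assumed behavior; the ESRGAN path additionally needs the downloaded checkpoint and an ESRGAN runtime):

from PIL import Image

def classical_upscale(image: Image.Image, method: str, scale: float) -> Image.Image:
    # Map the GUI name onto a PIL resampling filter and resize.
    resample = {"Lanczos": Image.LANCZOS, "Nearest": Image.NEAREST}[method]
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size, resample=resample)
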

+
  dynamic_gpu_duration.zerogpu = True
  sd_gen_generate_pipeline.zerogpu = True
+ sd_gen = GuiSD()
+

  from pathlib import Path
  from PIL import Image
  import random, json
  from modutils import (safe_float, escape_lora_basename, to_lora_key, to_lora_path,
      get_local_model_list, get_private_lora_model_lists, get_valid_lora_name,
+     get_valid_lora_path, get_valid_lora_wt, get_lora_info, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_BASEMODEL,
+     normalize_prompt_list, get_civitai_info, search_lora_on_civitai, translate_to_en, get_t2i_model_info, get_civitai_tag, save_image_history)
+

  #@spaces.GPU
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
          model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
@@ -801,7 +685,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
      gpu_duration = 59

      images: list[tuple[PIL.Image.Image, str | None]] = []
+     info_state = info_images = ""
      progress(0, desc="Preparing...")

      if randomize_seed:
@@ -828,7 +712,7 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
      sd_gen.load_new_model(model_name, vae, TASK_MODEL_LIST[0])
      progress(1, desc="Model loaded.")
      progress(0, desc="Starting Inference...")
+     info_state, images, info_images = sd_gen_generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
          guidance_scale, True, generator, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
          lora4, lora4_wt, lora5, lora5_wt, sampler,
          height, width, model_name, vae, TASK_MODEL_LIST[0], None, "Canny", 512, 1024,
@@ -1008,14 +892,14 @@ def update_lora_dict(path: str):
  def download_lora(dl_urls: str):
      global loras_url_to_path_dict
      dl_path = ""
+     before = get_local_model_list(DIRECTORY_LORAS)
      urls = []
      for url in [url.strip() for url in dl_urls.split(',')]:
+         local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
          if not Path(local_path).exists():
+             download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
              urls.append(url)
+     after = get_local_model_list(DIRECTORY_LORAS)
      new_files = list_sub(after, before)
      i = 0
      for file in new_files:
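
Note: download_lora() detects what actually arrived by diffing directory snapshots — list before, download, list again, and treat the difference as the new files (list_sub is the Space's helper; a plain comprehension is shown here):

import os

def snapshot(directory: str) -> list[str]:
    return sorted(os.listdir(directory)) if os.path.isdir(directory) else []

before = snapshot("./loras")
# ... downloads happen here ...
after = snapshot("./loras")
new_files = [f for f in after if f not in before]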