Duskfallcrew committed
Commit 2432f38 · verified · 1 Parent(s): 6a975c0

Update app.py


UGH THE CONVERSION AND DICT WAS NOT IN THERE BEFORE LAZY LLM

Files changed (1): app.py (+204 -45)
app.py CHANGED (old version, removals marked "-"):
@@ -1,35 +1,68 @@
  import os
  import gradio as gr
  import torch
- from diffusers import StableDiffusionXLPipeline
- from huggingface_hub import HfApi, login
- from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
- import tempfile
  import re
  import json
- import glob
  import gdown
  import requests
  import subprocess
  from urllib.parse import urlparse, unquote
  from pathlib import Path

  # ---------------------- UTILITY FUNCTIONS ----------------------

- def get_save_dtype(save_precision_as):
-     """Determines the save dtype based on the user's choice."""
-     if save_precision_as == "fp16":
-         return torch.float16
-     elif save_precision_as == "bf16":
-         return torch.bfloat16
-     elif save_precision_as == "float":
-         return torch.float32 # Using float32 for "float" option
      else:
-         return None

  def determine_load_checkpoint(model_to_load):
-     """Determines if the model to load is a checkpoint or a Diffusers model."""
-     if model_to_load.endswith('.ckpt') or model_to_load.endswith('.safetensors'):
          return True
      elif os.path.isdir(model_to_load):
          required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
@@ -37,15 +70,6 @@ def determine_load_checkpoint(model_to_load):
          return False
      return None # handle this case as required

- def increment_filename(filename):
-     """Increments the filename to avoid overwriting existing files."""
-     base, ext = os.path.splitext(filename)
-     counter = 1
-     while os.path.exists(filename):
-         filename = f"{base}({counter}){ext}"
-         counter += 1
-     return filename
-
  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
      """Creates a Hugging Face model repository if it doesn't exist."""
      if orgs_name == "":
@@ -69,6 +93,112 @@ def is_diffusers_model(model_path):
      required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
      return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))

  # ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------

  def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):
@@ -86,15 +216,43 @@ def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):

  def load_from_sdxl_checkpoint(args, output_widget):
      """Loads the SDXL model components from a checkpoint file (placeholder)."""
-     # text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-     #     "sdxl_base_v1-0", args.model_to_load, "cpu"
-     # )
-
-     # Implement Load model from ckpt or safetensors
      text_encoder1, text_encoder2, vae, unet = None, None, None, None

-     with output_widget:
-         print("Loading from Checkpoint not implemented, please implement based on your model needs.")

      return text_encoder1, text_encoder2, vae, unet

@@ -125,16 +283,14 @@ def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, sav

  def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
      """Saves the SDXL model components as a checkpoint file (placeholder)."""
-     # logit_scale = None
-     # ckpt_info = None
-
-     # key_count = sdxl_model_util.save_stable_diffusion_checkpoint(
-     #     args.model_to_save, text_encoder1, text_encoder2, unet, args.epoch, args.global_step, ckpt_info, vae, logit_scale, save_dtype
-     # )

      with output_widget:
-         print("Saving as Checkpoint not implemented, please implement based on your model needs.")
-         # print(f"Model saved. Total converted state_dict keys: {key_count}")

  def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
      """Saves the SDXL model as a Diffusers model."""
@@ -170,7 +326,6 @@ def convert_model(model_to_load, save_precision_as, epoch, global_step, referenc
          self.output_path = output_path # Using output_path even if hardcoded
          self.fp16 = fp16

-     # Create a temporary directory for output
      with tempfile.TemporaryDirectory() as tmpdirname:
          args = Args(model_to_load, save_precision_as, epoch, global_step, reference_model, tmpdirname, fp16)
          args.model_to_save = increment_filename(os.path.splitext(args.model_to_load)[0] + ".safetensors")
@@ -246,8 +401,6 @@ def main(model_to_load, save_precision_as, epoch, global_step, reference_model,
      """Main function orchestrating the entire process."""
      output = gr.Markdown()

-     # Hardcode output_path
-     #output_path = "./converted_model" ##This is incorrect! This will save to current working directory, which isnt ideal.
      # Create tempdir, will only be there for the function
      with tempfile.TemporaryDirectory() as output_path:
          conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, output)
@@ -263,6 +416,13 @@ with gr.Blocks() as demo:
      gr.Markdown(f"""
      ## **⚠️ IMPORTANT WARNINGS ⚠️**
      This app is partially coded by an LLM; for more information, see [Ktiseos Nyx](https://github.com/Ktiseos-Nyx/Sdxl-to-diffusers). The Colab edition of this may indeed break the AUP. This Space runs on CPU and in theory SHOULD work, but may be slow. Earth and Dusk / Ktiseos Nyx does not have the enterprise budget for ZERO GPU or any GPU, sadly! Thank you to the community, especially John6666, for coming to our aid when Gemini would NOT fix the requirements. Support Ktiseos Nyx & myself on Ko-fi: [![Ko-fi](https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat)](https://ko-fi.com/Z8Z8L4EO)
      """)

      model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
@@ -277,10 +437,9 @@ with gr.Blocks() as demo:

      reference_model = gr.Textbox(label="Reference Diffusers Model",
                                   placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
-     #output_path = gr.Textbox(label="Output Path", value="./converted_model") #Remove text box - using temp file approach

      gr.Markdown("## Hugging Face Hub Configuration")
-     hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
      with gr.Row():
          orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
          model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
 
app.py CHANGED (updated version, additions marked "+"):

@@ -1,35 +1,68 @@
  import os
  import gradio as gr
  import torch
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTextConfig
+ from safetensors.torch import load_file
+ from collections import OrderedDict
  import re
  import json
  import gdown
  import requests
  import subprocess
  from urllib.parse import urlparse, unquote
  from pathlib import Path
+ import tempfile
+ from tqdm import tqdm

  # ---------------------- UTILITY FUNCTIONS ----------------------

+ def is_valid_url(url):
+     """Checks if a string is a valid URL."""
+     try:
+         result = urlparse(url)
+         return all([result.scheme, result.netloc])
+     except ValueError:
+         return False
+
+ def get_filename(url):
+     """Gets a filename from the Content-Disposition header, falling back to the URL path."""
+     response = requests.get(url, stream=True)
+     response.raise_for_status()
+
+     if 'content-disposition' in response.headers:
+         content_disposition = response.headers['content-disposition']
+         filename = re.findall('filename="?([^"]+)"?', content_disposition)[0]
      else:
+         url_path = urlparse(url).path
+         filename = unquote(os.path.basename(url_path))
+
+     return filename
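+ # e.g. get_filename("https://example.com/files/model.safetensors") returns
+ # "model.safetensors" when the server sends no Content-Disposition header
+ # (hypothetical URL, shown for illustration only).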
+
+ def get_supported_extensions():
+     return (".ckpt", ".safetensors", ".pt", ".pth")
+
+ def download_model(url, dst, output_widget):
+     filename = get_filename(url)
+     filepath = os.path.join(dst, filename)
+     try:
+         if "drive.google.com" in url:
+             filepath = gdown.download(url, filepath, fuzzy=True)
+         else:
+             if "huggingface.co" in url:
+                 if "/blob/" in url:
+                     url = url.replace("/blob/", "/resolve/")
+             subprocess.run(["aria2c", "-x", "16", url, "-d", dst, "-o", filename])
+         with output_widget:
+             return filepath
+     except Exception as e:
+         with output_widget:
+             return None
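+ # NOTE: the aria2c branch assumes the aria2c binary is available on PATH (an
+ # assumption about the runtime image); if it is missing, subprocess.run raises
+ # FileNotFoundError, which the except branch above turns into a None return.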
 
  def determine_load_checkpoint(model_to_load):
+     """Determines if the model to load is a checkpoint, Diffusers model, or URL."""
+     if is_valid_url(model_to_load) and (model_to_load.endswith(get_supported_extensions())):
+         return True
+     elif model_to_load.endswith(get_supported_extensions()):
          return True
      elif os.path.isdir(model_to_load):
          required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
 
@@ -37,15 +70,6 @@ def determine_load_checkpoint(model_to_load):
          return False
      return None # handle this case as required

  def create_model_repo(api, user, orgs_name, model_name, make_private=False):
      """Creates a Hugging Face model repository if it doesn't exist."""
      if orgs_name == "":
 
@@ -69,6 +93,112 @@ def is_diffusers_model(model_path):
      required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
      return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))

+ # ---------------------- MODEL UTIL (From library.sdxl_model_util) ----------------------
+ def load_models_from_sdxl_checkpoint(sdxl_base_id, checkpoint_path, device):
+     """Loads SDXL model components from a checkpoint file."""
+     text_encoder1 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder").to(device)
+     # text_encoder_2 in SDXL repos is a CLIPTextModelWithProjection, not a plain CLIPTextModel.
+     text_encoder2 = CLIPTextModelWithProjection.from_pretrained(sdxl_base_id, subfolder="text_encoder_2").to(device)
+     vae = AutoencoderKL.from_pretrained(sdxl_base_id, subfolder="vae").to(device)
+     unet = UNet2DConditionModel.from_pretrained(sdxl_base_id, subfolder="unet").to(device)
+
+     # torch.load cannot read .safetensors files; use safetensors' load_file for those.
+     if checkpoint_path.endswith(".safetensors"):
+         ckpt_state_dict = load_file(checkpoint_path, device=device)
+     else:
+         ckpt_state_dict = torch.load(checkpoint_path, map_location=device)
+
+     o = OrderedDict()
+     for key in list(ckpt_state_dict.keys()):
+         o[key.replace("module.", "")] = ckpt_state_dict[key]
+     del ckpt_state_dict
+
+     print("Applying weights to text encoder 1:")
+     text_encoder1.load_state_dict({
+         '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.cond_stage_model.model.transformer")
+     }, strict=False)
+     print("Applying weights to text encoder 2:")
+     text_encoder2.load_state_dict({
+         '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("cond_stage_model.model.transformer")
+     }, strict=False)
+     print("Applying weights to VAE:")
+     vae.load_state_dict({
+         '.'.join(key.split('.')[2:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.model")
+     }, strict=False)
+     print("Applying weights to UNet:")
+     unet.load_state_dict({
+         key: o[key] for key in list(o.keys()) if key.startswith("model.diffusion_model")
+     }, strict=False)
+
+     logit_scale = None  # Not used here!
+     global_step = None  # Not used here!
+     return text_encoder1, text_encoder2, vae, unet, logit_scale, global_step
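+ # NOTE: strict=False means checkpoint keys that do not match the prefixes above
+ # are silently ignored rather than raising, so a prefix mismatch leaves the
+ # corresponding component at its base-model weights instead of failing loudly.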
+
+ def save_stable_diffusion_checkpoint(save_path, text_encoder1, text_encoder2, unet, epoch, global_step, ckpt_info, vae, logit_scale, save_dtype):
+     """Saves the stable diffusion checkpoint."""
+     weights = OrderedDict()
+     text_encoder1_dict = text_encoder1.state_dict()
+     text_encoder2_dict = text_encoder2.state_dict()
+     unet_dict = unet.state_dict()
+     vae_dict = vae.state_dict()
+
+     def replace_key(key):
+         key = "cond_stage_model.model.transformer." + key
+         return key
+
+     print("Merging text encoder 1")
+     for key in tqdm(list(text_encoder1_dict.keys())):
+         weights["first_stage_model.cond_stage_model.model.transformer." + key] = text_encoder1_dict[key].to(save_dtype)
+
+     print("Merging text encoder 2")
+     for key in tqdm(list(text_encoder2_dict.keys())):
+         weights[replace_key(key)] = text_encoder2_dict[key].to(save_dtype)
+
+     print("Merging vae")
+     for key in tqdm(list(vae_dict.keys())):
+         weights["first_stage_model.model." + key] = vae_dict[key].to(save_dtype)
+
+     print("Merging unet")
+     for key in tqdm(list(unet_dict.keys())):
+         weights["model.diffusion_model." + key] = unet_dict[key].to(save_dtype)
+
+     info = {"epoch": epoch, "global_step": global_step}
+     if ckpt_info is not None:
+         info.update(ckpt_info)
+
+     if logit_scale is not None:
+         info["logit_scale"] = logit_scale.item()
+
+     torch.save({"state_dict": weights, "info": info}, save_path)
+
+     key_count = len(weights.keys())
+     del weights
+     del text_encoder1_dict, text_encoder2_dict, unet_dict, vae_dict
+     return key_count
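+ # NOTE: torch.save writes torch's pickle format even when save_path ends in
+ # ".safetensors"; producing a real safetensors file would instead use
+ # safetensors.torch.save_file(weights, save_path), which accepts tensors only.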
+
+ def save_diffusers_checkpoint(save_path, text_encoder1, text_encoder2, unet, reference_model, vae, trim_if_model_exists, save_dtype):
+     """Saves Diffusers-style checkpoint from the model."""
+     print("Saving SDXL as Diffusers format to:", save_path)
+     print("SDXL Text Encoder 1 to:", os.path.join(save_path, "text_encoder"))
+     text_encoder1.save_pretrained(os.path.join(save_path, "text_encoder"))
+
+     print("SDXL Text Encoder 2 to:", os.path.join(save_path, "text_encoder_2"))
+     text_encoder2.save_pretrained(os.path.join(save_path, "text_encoder_2"))
+
+     print("SDXL VAE to:", os.path.join(save_path, "vae"))
+     vae.save_pretrained(os.path.join(save_path, "vae"))
+
+     print("SDXL UNet to:", os.path.join(save_path, "unet"))
+     unet.save_pretrained(os.path.join(save_path, "unet"))
+
+     if reference_model is None:
+         # No reference model given; fall back to the SDXL base repo (assumed default).
+         print("No reference model. Copying scheduler from stabilityai/stable-diffusion-xl-base-1.0.")
+         reference_model = "stabilityai/stable-diffusion-xl-base-1.0"
+     print(f"Copying scheduler from {reference_model}")
+     scheduler_src = StableDiffusionXLPipeline.from_pretrained(reference_model, torch_dtype=torch.float16).scheduler
+     scheduler_src.save_pretrained(os.path.join(save_path, "scheduler"))
+
+     if trim_if_model_exists:
+         print("Trim Complete")
202
  # ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------
203
 
204
  def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):
 

@@ -86,15 +216,43 @@ def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):

  def load_from_sdxl_checkpoint(args, output_widget):
      """Loads the SDXL model components from a checkpoint file (placeholder)."""
      text_encoder1, text_encoder2, vae, unet = None, None, None, None
+     device = "cpu"
+     if is_valid_url(args.model_to_load):
+         download_dst_dir = tempfile.mkdtemp()
+         model_path = download_model(args.model_to_load, download_dst_dir, output_widget)
+         if model_path is None:
+             with output_widget:
+                 print("Loading from Checkpoint failed, the request could not be completed")
+             return text_encoder1, text_encoder2, vae, unet
+     else:
+         model_path = args.model_to_load
+
+     # Load the model components from a .ckpt or .safetensors checkpoint.
+     # "stabilityai/stable-diffusion-xl-base-1.0" supplies the base architectures
+     # (the original "sdxl_base_v1-0" string is not a valid Hub repo id).
+     try:
+         text_encoder1, text_encoder2, vae, unet, _, _ = load_models_from_sdxl_checkpoint(
+             "stabilityai/stable-diffusion-xl-base-1.0", model_path, device
+         )
+     except Exception as e:
+         print(f"Could not load SDXL from checkpoint due to: \n{e}")

      return text_encoder1, text_encoder2, vae, unet
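+ # NOTE: the "with output_widget:" blocks throughout this file look like the
+ # ipywidgets Output capture idiom; a gr.Markdown component is not a context
+ # manager, so these blocks are likely to fail at runtime and probably need to
+ # be replaced by returning the message strings instead.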
 
 
@@ -125,16 +283,14 @@ def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, sav

  def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
      """Saves the SDXL model components as a checkpoint file (placeholder)."""
+     logit_scale = None
+     ckpt_info = None

+     key_count = save_stable_diffusion_checkpoint(
+         args.model_to_save, text_encoder1, text_encoder2, unet, args.epoch, args.global_step, ckpt_info, vae, logit_scale, save_dtype
+     )
      with output_widget:
+         print(f"Model saved. Total converted state_dict keys: {key_count}")

  def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
      """Saves the SDXL model as a Diffusers model."""
 
@@ -170,7 +326,6 @@ def convert_model(model_to_load, save_precision_as, epoch, global_step, referenc
          self.output_path = output_path # Using output_path even if hardcoded
          self.fp16 = fp16

      with tempfile.TemporaryDirectory() as tmpdirname:
          args = Args(model_to_load, save_precision_as, epoch, global_step, reference_model, tmpdirname, fp16)
          args.model_to_save = increment_filename(os.path.splitext(args.model_to_load)[0] + ".safetensors")
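+ # NOTE: increment_filename() is still called on the line above even though this
+ # commit removes its definition; the helper (or an equivalent) has to be kept
+ # in the file for this call to work.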
 
@@ -246,8 +401,6 @@ def main(model_to_load, save_precision_as, epoch, global_step, reference_model,
      """Main function orchestrating the entire process."""
      output = gr.Markdown()

      # Create tempdir, will only be there for the function
      with tempfile.TemporaryDirectory() as output_path:
          conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, output)
 
@@ -263,6 +416,13 @@ with gr.Blocks() as demo:
      gr.Markdown(f"""
      ## **⚠️ IMPORTANT WARNINGS ⚠️**
      This app is partially coded by an LLM; for more information, see [Ktiseos Nyx](https://github.com/Ktiseos-Nyx/Sdxl-to-diffusers). The Colab edition of this may indeed break the AUP. This Space runs on CPU and in theory SHOULD work, but may be slow. Earth and Dusk / Ktiseos Nyx does not have the enterprise budget for ZERO GPU or any GPU, sadly! Thank you to the community, especially John6666, for coming to our aid when Gemini would NOT fix the requirements. Support Ktiseos Nyx & myself on Ko-fi: [![Ko-fi](https://img.shields.io/badge/Support%20me%20on%20Ko--fi-F16061?logo=ko-fi&logoColor=white&style=flat)](https://ko-fi.com/Z8Z8L4EO)
+
+     **Understanding the 'Model to Load' Input:**
+
+     This field can accept any of the following:
+     * A Hugging Face model identifier (e.g., `stabilityai/stable-diffusion-xl-base-1.0`).
+     * A direct URL to a .ckpt or .safetensors model file.
+     * **Important:** Hugging Face direct links must use the `/resolve/` form, i.e. end with `/resolve/main/` followed by the file name, not `/blob/`.
      """)

      model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
 
@@ -277,10 +437,9 @@

      reference_model = gr.Textbox(label="Reference Diffusers Model",
                                   placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")

      gr.Markdown("## Hugging Face Hub Configuration")
+     hf_token = gr.Textbox(type="password", label="Hugging Face Token", placeholder="Your Hugging Face write token")  # THIS IS NEEDED
      with gr.Row():
          orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
          model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
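+ # NOTE: this commit also removes "from huggingface_hub import HfApi, login" at
+ # the top of the file, but create_model_repo() above still expects an HfApi
+ # instance, so that import has to come back for the Hub upload path to work.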