Upload 3 files
Browse files- blocks.py +0 -0
- extras.py +330 -0
- sd_models.py +840 -0
blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
extras.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import shutil
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import tqdm
|
9 |
+
|
10 |
+
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
11 |
+
from modules.ui_common import plaintext_to_html
|
12 |
+
import gradio as gr
|
13 |
+
import safetensors.torch
|
14 |
+
|
15 |
+
|
16 |
+
def run_pnginfo(image):
|
17 |
+
if image is None:
|
18 |
+
return '', '', ''
|
19 |
+
|
20 |
+
geninfo, items = images.read_info_from_image(image)
|
21 |
+
items = {**{'parameters': geninfo}, **items}
|
22 |
+
|
23 |
+
info = ''
|
24 |
+
for key, text in items.items():
|
25 |
+
info += f"""
|
26 |
+
<div>
|
27 |
+
<p><b>{plaintext_to_html(str(key))}</b></p>
|
28 |
+
<p>{plaintext_to_html(str(text))}</p>
|
29 |
+
</div>
|
30 |
+
""".strip()+"\n"
|
31 |
+
|
32 |
+
if len(info) == 0:
|
33 |
+
message = "Nothing found in the image."
|
34 |
+
info = f"<div><p>{message}<p></div>"
|
35 |
+
|
36 |
+
return '', geninfo, info
|
37 |
+
|
38 |
+
|
39 |
+
def create_config(ckpt_result, config_source, a, b, c):
|
40 |
+
def config(x):
|
41 |
+
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
42 |
+
return res if res != shared.sd_default_config else None
|
43 |
+
|
44 |
+
if config_source == 0:
|
45 |
+
cfg = config(a) or config(b) or config(c)
|
46 |
+
elif config_source == 1:
|
47 |
+
cfg = config(b)
|
48 |
+
elif config_source == 2:
|
49 |
+
cfg = config(c)
|
50 |
+
else:
|
51 |
+
cfg = None
|
52 |
+
|
53 |
+
if cfg is None:
|
54 |
+
return
|
55 |
+
|
56 |
+
filename, _ = os.path.splitext(ckpt_result)
|
57 |
+
checkpoint_filename = filename + ".yaml"
|
58 |
+
|
59 |
+
print("Copying config:")
|
60 |
+
print(" from:", cfg)
|
61 |
+
print(" to:", checkpoint_filename)
|
62 |
+
shutil.copyfile(cfg, checkpoint_filename)
|
63 |
+
|
64 |
+
|
65 |
+
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
66 |
+
|
67 |
+
|
68 |
+
def to_half(tensor, enable):
|
69 |
+
if enable and tensor.dtype == torch.float:
|
70 |
+
return tensor.half()
|
71 |
+
|
72 |
+
return tensor
|
73 |
+
|
74 |
+
|
75 |
+
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
76 |
+
metadata = {}
|
77 |
+
|
78 |
+
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
79 |
+
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
80 |
+
if checkpoint_info is None:
|
81 |
+
continue
|
82 |
+
|
83 |
+
metadata.update(checkpoint_info.metadata)
|
84 |
+
|
85 |
+
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
86 |
+
|
87 |
+
|
88 |
+
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
89 |
+
shared.state.begin(job="model-merge")
|
90 |
+
|
91 |
+
def fail(message):
|
92 |
+
shared.state.textinfo = message
|
93 |
+
shared.state.end()
|
94 |
+
return [*[gr.update() for _ in range(4)], message]
|
95 |
+
|
96 |
+
def weighted_sum(theta0, theta1, alpha):
|
97 |
+
return ((1 - alpha) * theta0) + (alpha * theta1)
|
98 |
+
|
99 |
+
def get_difference(theta1, theta2):
|
100 |
+
return theta1 - theta2
|
101 |
+
|
102 |
+
def add_difference(theta0, theta1_2_diff, alpha):
|
103 |
+
return theta0 + (alpha * theta1_2_diff)
|
104 |
+
|
105 |
+
def filename_weighted_sum():
|
106 |
+
a = primary_model_info.model_name
|
107 |
+
b = secondary_model_info.model_name
|
108 |
+
Ma = round(1 - multiplier, 2)
|
109 |
+
Mb = round(multiplier, 2)
|
110 |
+
|
111 |
+
return f"{Ma}({a}) + {Mb}({b})"
|
112 |
+
|
113 |
+
def filename_add_difference():
|
114 |
+
a = primary_model_info.model_name
|
115 |
+
b = secondary_model_info.model_name
|
116 |
+
c = tertiary_model_info.model_name
|
117 |
+
M = round(multiplier, 2)
|
118 |
+
|
119 |
+
return f"{a} + {M}({b} - {c})"
|
120 |
+
|
121 |
+
def filename_nothing():
|
122 |
+
return primary_model_info.model_name
|
123 |
+
|
124 |
+
theta_funcs = {
|
125 |
+
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
126 |
+
"Add difference": (filename_add_difference, get_difference, add_difference),
|
127 |
+
"No interpolation": (filename_nothing, None, None),
|
128 |
+
}
|
129 |
+
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
130 |
+
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
131 |
+
|
132 |
+
if not primary_model_name:
|
133 |
+
return fail("Failed: Merging requires a primary model.")
|
134 |
+
|
135 |
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
136 |
+
|
137 |
+
if theta_func2 and not secondary_model_name:
|
138 |
+
return fail("Failed: Merging requires a secondary model.")
|
139 |
+
|
140 |
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
141 |
+
|
142 |
+
if theta_func1 and not tertiary_model_name:
|
143 |
+
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
144 |
+
|
145 |
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
146 |
+
|
147 |
+
result_is_inpainting_model = False
|
148 |
+
result_is_instruct_pix2pix_model = False
|
149 |
+
|
150 |
+
if theta_func2:
|
151 |
+
shared.state.textinfo = "Loading B"
|
152 |
+
print(f"Loading {secondary_model_info.filename}...")
|
153 |
+
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
154 |
+
else:
|
155 |
+
theta_1 = None
|
156 |
+
|
157 |
+
if theta_func1:
|
158 |
+
shared.state.textinfo = "Loading C"
|
159 |
+
print(f"Loading {tertiary_model_info.filename}...")
|
160 |
+
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
161 |
+
|
162 |
+
shared.state.textinfo = 'Merging B and C'
|
163 |
+
shared.state.sampling_steps = len(theta_1.keys())
|
164 |
+
for key in tqdm.tqdm(theta_1.keys()):
|
165 |
+
if key in checkpoint_dict_skip_on_merge:
|
166 |
+
continue
|
167 |
+
|
168 |
+
if 'model' in key:
|
169 |
+
if key in theta_2:
|
170 |
+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
171 |
+
theta_1[key] = theta_func1(theta_1[key], t2)
|
172 |
+
else:
|
173 |
+
theta_1[key] = torch.zeros_like(theta_1[key])
|
174 |
+
|
175 |
+
shared.state.sampling_step += 1
|
176 |
+
del theta_2
|
177 |
+
|
178 |
+
shared.state.nextjob()
|
179 |
+
|
180 |
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
181 |
+
print(f"Loading {primary_model_info.filename}...")
|
182 |
+
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
183 |
+
|
184 |
+
print("Merging...")
|
185 |
+
shared.state.textinfo = 'Merging A and B'
|
186 |
+
shared.state.sampling_steps = len(theta_0.keys())
|
187 |
+
for key in tqdm.tqdm(theta_0.keys()):
|
188 |
+
if theta_1 and 'model' in key and key in theta_1:
|
189 |
+
|
190 |
+
if key in checkpoint_dict_skip_on_merge:
|
191 |
+
continue
|
192 |
+
|
193 |
+
a = theta_0[key]
|
194 |
+
b = theta_1[key]
|
195 |
+
|
196 |
+
# this enables merging an inpainting model (A) with another one (B);
|
197 |
+
# where normal model would have 4 channels, for latenst space, inpainting model would
|
198 |
+
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
199 |
+
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
200 |
+
if a.shape[1] == 4 and b.shape[1] == 9:
|
201 |
+
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
202 |
+
if a.shape[1] == 4 and b.shape[1] == 8:
|
203 |
+
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
204 |
+
|
205 |
+
if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
|
206 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
|
207 |
+
result_is_instruct_pix2pix_model = True
|
208 |
+
else:
|
209 |
+
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
210 |
+
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
211 |
+
result_is_inpainting_model = True
|
212 |
+
else:
|
213 |
+
theta_0[key] = theta_func2(a, b, multiplier)
|
214 |
+
|
215 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
216 |
+
|
217 |
+
shared.state.sampling_step += 1
|
218 |
+
|
219 |
+
del theta_1
|
220 |
+
|
221 |
+
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
222 |
+
if bake_in_vae_filename is not None:
|
223 |
+
print(f"Baking in VAE from {bake_in_vae_filename}")
|
224 |
+
shared.state.textinfo = 'Baking in VAE'
|
225 |
+
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
226 |
+
|
227 |
+
for key in vae_dict.keys():
|
228 |
+
theta_0_key = 'first_stage_model.' + key
|
229 |
+
if theta_0_key in theta_0:
|
230 |
+
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
231 |
+
|
232 |
+
del vae_dict
|
233 |
+
|
234 |
+
if save_as_half and not theta_func2:
|
235 |
+
for key in theta_0.keys():
|
236 |
+
theta_0[key] = to_half(theta_0[key], save_as_half)
|
237 |
+
|
238 |
+
if discard_weights:
|
239 |
+
regex = re.compile(discard_weights)
|
240 |
+
for key in list(theta_0):
|
241 |
+
if re.search(regex, key):
|
242 |
+
theta_0.pop(key, None)
|
243 |
+
|
244 |
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
245 |
+
|
246 |
+
filename = filename_generator() if custom_name == '' else custom_name
|
247 |
+
filename += ".inpainting" if result_is_inpainting_model else ""
|
248 |
+
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
249 |
+
filename += "." + checkpoint_format
|
250 |
+
|
251 |
+
output_modelname = os.path.join(ckpt_dir, filename)
|
252 |
+
|
253 |
+
shared.state.nextjob()
|
254 |
+
shared.state.textinfo = "Saving"
|
255 |
+
print(f"Saving to {output_modelname}...")
|
256 |
+
|
257 |
+
metadata = {}
|
258 |
+
|
259 |
+
if save_metadata and copy_metadata_fields:
|
260 |
+
if primary_model_info:
|
261 |
+
metadata.update(primary_model_info.metadata)
|
262 |
+
if secondary_model_info:
|
263 |
+
metadata.update(secondary_model_info.metadata)
|
264 |
+
if tertiary_model_info:
|
265 |
+
metadata.update(tertiary_model_info.metadata)
|
266 |
+
|
267 |
+
if save_metadata:
|
268 |
+
try:
|
269 |
+
metadata.update(json.loads(metadata_json))
|
270 |
+
except Exception as e:
|
271 |
+
errors.display(e, "readin metadata from json")
|
272 |
+
|
273 |
+
metadata["format"] = "pt"
|
274 |
+
|
275 |
+
if save_metadata and add_merge_recipe:
|
276 |
+
merge_recipe = {
|
277 |
+
"type": "webui", # indicate this model was merged with webui's built-in merger
|
278 |
+
"primary_model_hash": primary_model_info.sha256,
|
279 |
+
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
280 |
+
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
281 |
+
"interp_method": interp_method,
|
282 |
+
"multiplier": multiplier,
|
283 |
+
"save_as_half": save_as_half,
|
284 |
+
"custom_name": custom_name,
|
285 |
+
"config_source": config_source,
|
286 |
+
"bake_in_vae": bake_in_vae,
|
287 |
+
"discard_weights": discard_weights,
|
288 |
+
"is_inpainting": result_is_inpainting_model,
|
289 |
+
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
290 |
+
}
|
291 |
+
|
292 |
+
sd_merge_models = {}
|
293 |
+
|
294 |
+
def add_model_metadata(checkpoint_info):
|
295 |
+
checkpoint_info.calculate_shorthash()
|
296 |
+
sd_merge_models[checkpoint_info.sha256] = {
|
297 |
+
"name": checkpoint_info.name,
|
298 |
+
"legacy_hash": checkpoint_info.hash,
|
299 |
+
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
300 |
+
}
|
301 |
+
|
302 |
+
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
303 |
+
|
304 |
+
add_model_metadata(primary_model_info)
|
305 |
+
if secondary_model_info:
|
306 |
+
add_model_metadata(secondary_model_info)
|
307 |
+
if tertiary_model_info:
|
308 |
+
add_model_metadata(tertiary_model_info)
|
309 |
+
|
310 |
+
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
311 |
+
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
312 |
+
|
313 |
+
_, extension = os.path.splitext(output_modelname)
|
314 |
+
if extension.lower() == ".safetensors":
|
315 |
+
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
|
316 |
+
else:
|
317 |
+
torch.save(theta_0, output_modelname)
|
318 |
+
|
319 |
+
sd_models.list_models()
|
320 |
+
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
321 |
+
if created_model:
|
322 |
+
created_model.calculate_shorthash()
|
323 |
+
|
324 |
+
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
325 |
+
|
326 |
+
print(f"Checkpoint saved to {output_modelname}.")
|
327 |
+
shared.state.textinfo = "Checkpoint saved"
|
328 |
+
shared.state.end()
|
329 |
+
|
330 |
+
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
sd_models.py
ADDED
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import os.path
|
3 |
+
import sys
|
4 |
+
import threading
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import re
|
8 |
+
import safetensors.torch
|
9 |
+
from omegaconf import OmegaConf, ListConfig
|
10 |
+
from os import mkdir
|
11 |
+
from urllib import request
|
12 |
+
import ldm.modules.midas as midas
|
13 |
+
|
14 |
+
from ldm.util import instantiate_from_config
|
15 |
+
|
16 |
+
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
17 |
+
from modules.timer import Timer
|
18 |
+
import tomesd
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
model_dir = "Stable-diffusion"
|
22 |
+
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
23 |
+
|
24 |
+
checkpoints_list = {}
|
25 |
+
checkpoint_aliases = {}
|
26 |
+
checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
27 |
+
checkpoints_loaded = collections.OrderedDict()
|
28 |
+
|
29 |
+
|
30 |
+
def replace_key(d, key, new_key, value):
|
31 |
+
keys = list(d.keys())
|
32 |
+
|
33 |
+
d[new_key] = value
|
34 |
+
|
35 |
+
if key not in keys:
|
36 |
+
return d
|
37 |
+
|
38 |
+
index = keys.index(key)
|
39 |
+
keys[index] = new_key
|
40 |
+
|
41 |
+
new_d = {k: d[k] for k in keys}
|
42 |
+
|
43 |
+
d.clear()
|
44 |
+
d.update(new_d)
|
45 |
+
return d
|
46 |
+
|
47 |
+
|
48 |
+
class CheckpointInfo:
|
49 |
+
def __init__(self, filename):
|
50 |
+
self.filename = filename
|
51 |
+
abspath = os.path.abspath(filename)
|
52 |
+
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
|
53 |
+
|
54 |
+
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
55 |
+
|
56 |
+
if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
|
57 |
+
name = abspath.replace(abs_ckpt_dir, '')
|
58 |
+
elif abspath.startswith(model_path):
|
59 |
+
name = abspath.replace(model_path, '')
|
60 |
+
else:
|
61 |
+
name = os.path.basename(filename)
|
62 |
+
|
63 |
+
if name.startswith("\\") or name.startswith("/"):
|
64 |
+
name = name[1:]
|
65 |
+
|
66 |
+
def read_metadata():
|
67 |
+
metadata = read_metadata_from_safetensors(filename)
|
68 |
+
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
69 |
+
|
70 |
+
return metadata
|
71 |
+
|
72 |
+
self.metadata = {}
|
73 |
+
if self.is_safetensors:
|
74 |
+
try:
|
75 |
+
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
76 |
+
except Exception as e:
|
77 |
+
errors.display(e, f"reading metadata for {filename}")
|
78 |
+
|
79 |
+
self.name = name
|
80 |
+
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
81 |
+
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
82 |
+
self.hash = model_hash(filename)
|
83 |
+
|
84 |
+
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
|
85 |
+
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
86 |
+
|
87 |
+
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
88 |
+
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
89 |
+
|
90 |
+
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
|
91 |
+
if self.shorthash:
|
92 |
+
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
93 |
+
|
94 |
+
def register(self):
|
95 |
+
checkpoints_list[self.title] = self
|
96 |
+
for id in self.ids:
|
97 |
+
checkpoint_aliases[id] = self
|
98 |
+
|
99 |
+
def calculate_shorthash(self):
|
100 |
+
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
101 |
+
if self.sha256 is None:
|
102 |
+
return
|
103 |
+
|
104 |
+
shorthash = self.sha256[0:10]
|
105 |
+
if self.shorthash == self.sha256[0:10]:
|
106 |
+
return self.shorthash
|
107 |
+
|
108 |
+
self.shorthash = shorthash
|
109 |
+
|
110 |
+
if self.shorthash not in self.ids:
|
111 |
+
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
112 |
+
|
113 |
+
old_title = self.title
|
114 |
+
self.title = f'{self.name} [{self.shorthash}]'
|
115 |
+
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
116 |
+
|
117 |
+
replace_key(checkpoints_list, old_title, self.title, self)
|
118 |
+
self.register()
|
119 |
+
|
120 |
+
return self.shorthash
|
121 |
+
|
122 |
+
|
123 |
+
try:
|
124 |
+
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
125 |
+
from transformers import logging, CLIPModel # noqa: F401
|
126 |
+
|
127 |
+
logging.set_verbosity_error()
|
128 |
+
except Exception:
|
129 |
+
pass
|
130 |
+
|
131 |
+
|
132 |
+
def setup_model():
|
133 |
+
"""called once at startup to do various one-time tasks related to SD models"""
|
134 |
+
|
135 |
+
os.makedirs(model_path, exist_ok=True)
|
136 |
+
|
137 |
+
enable_midas_autodownload()
|
138 |
+
patch_given_betas()
|
139 |
+
|
140 |
+
|
141 |
+
def checkpoint_tiles(use_short=False):
|
142 |
+
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
|
143 |
+
|
144 |
+
|
145 |
+
def list_models():
|
146 |
+
checkpoints_list.clear()
|
147 |
+
checkpoint_aliases.clear()
|
148 |
+
|
149 |
+
cmd_ckpt = shared.cmd_opts.ckpt
|
150 |
+
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
151 |
+
model_url = None
|
152 |
+
else:
|
153 |
+
model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
|
154 |
+
|
155 |
+
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
156 |
+
|
157 |
+
if os.path.exists(cmd_ckpt):
|
158 |
+
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
159 |
+
checkpoint_info.register()
|
160 |
+
|
161 |
+
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
162 |
+
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
163 |
+
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
164 |
+
|
165 |
+
for filename in model_list:
|
166 |
+
checkpoint_info = CheckpointInfo(filename)
|
167 |
+
checkpoint_info.register()
|
168 |
+
|
169 |
+
|
170 |
+
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
171 |
+
|
172 |
+
|
173 |
+
def get_closet_checkpoint_match(search_string):
|
174 |
+
if not search_string:
|
175 |
+
return None
|
176 |
+
|
177 |
+
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
178 |
+
if checkpoint_info is not None:
|
179 |
+
return checkpoint_info
|
180 |
+
|
181 |
+
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
182 |
+
if found:
|
183 |
+
return found[0]
|
184 |
+
|
185 |
+
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
186 |
+
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
187 |
+
if found:
|
188 |
+
return found[0]
|
189 |
+
|
190 |
+
return None
|
191 |
+
|
192 |
+
|
193 |
+
def model_hash(filename):
|
194 |
+
"""old hash that only looks at a small part of the file and is prone to collisions"""
|
195 |
+
|
196 |
+
try:
|
197 |
+
with open(filename, "rb") as file:
|
198 |
+
import hashlib
|
199 |
+
m = hashlib.sha256()
|
200 |
+
|
201 |
+
file.seek(0x100000)
|
202 |
+
m.update(file.read(0x10000))
|
203 |
+
return m.hexdigest()[0:8]
|
204 |
+
except FileNotFoundError:
|
205 |
+
return 'NOFILE'
|
206 |
+
|
207 |
+
|
208 |
+
def select_checkpoint():
|
209 |
+
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
210 |
+
model_checkpoint = shared.opts.sd_model_checkpoint
|
211 |
+
|
212 |
+
checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
|
213 |
+
if checkpoint_info is not None:
|
214 |
+
return checkpoint_info
|
215 |
+
|
216 |
+
if len(checkpoints_list) == 0:
|
217 |
+
error_message = "No checkpoints found. When searching for checkpoints, looked at:"
|
218 |
+
if shared.cmd_opts.ckpt is not None:
|
219 |
+
error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
|
220 |
+
error_message += f"\n - directory {model_path}"
|
221 |
+
if shared.cmd_opts.ckpt_dir is not None:
|
222 |
+
error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
|
223 |
+
error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
|
224 |
+
raise FileNotFoundError(error_message)
|
225 |
+
|
226 |
+
checkpoint_info = next(iter(checkpoints_list.values()))
|
227 |
+
if model_checkpoint is not None:
|
228 |
+
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
|
229 |
+
|
230 |
+
return checkpoint_info
|
231 |
+
|
232 |
+
|
233 |
+
checkpoint_dict_replacements_sd1 = {
|
234 |
+
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
235 |
+
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
236 |
+
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
237 |
+
}
|
238 |
+
|
239 |
+
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
|
240 |
+
'conditioner.embedders.0.': 'cond_stage_model.',
|
241 |
+
}
|
242 |
+
|
243 |
+
|
244 |
+
def transform_checkpoint_dict_key(k, replacements):
|
245 |
+
for text, replacement in replacements.items():
|
246 |
+
if k.startswith(text):
|
247 |
+
k = replacement + k[len(text):]
|
248 |
+
|
249 |
+
return k
|
250 |
+
|
251 |
+
|
252 |
+
def get_state_dict_from_checkpoint(pl_sd):
|
253 |
+
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
254 |
+
pl_sd.pop("state_dict", None)
|
255 |
+
|
256 |
+
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
|
257 |
+
|
258 |
+
sd = {}
|
259 |
+
for k, v in pl_sd.items():
|
260 |
+
if is_sd2_turbo:
|
261 |
+
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
|
262 |
+
else:
|
263 |
+
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
|
264 |
+
|
265 |
+
if new_key is not None:
|
266 |
+
sd[new_key] = v
|
267 |
+
|
268 |
+
pl_sd.clear()
|
269 |
+
pl_sd.update(sd)
|
270 |
+
|
271 |
+
return pl_sd
|
272 |
+
|
273 |
+
|
274 |
+
def read_metadata_from_safetensors(filename):
|
275 |
+
import json
|
276 |
+
|
277 |
+
with open(filename, mode="rb") as file:
|
278 |
+
metadata_len = file.read(8)
|
279 |
+
metadata_len = int.from_bytes(metadata_len, "little")
|
280 |
+
json_start = file.read(2)
|
281 |
+
|
282 |
+
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
|
283 |
+
json_data = json_start + file.read(metadata_len-2)
|
284 |
+
json_obj = json.loads(json_data)
|
285 |
+
|
286 |
+
res = {}
|
287 |
+
for k, v in json_obj.get("__metadata__", {}).items():
|
288 |
+
res[k] = v
|
289 |
+
if isinstance(v, str) and v[0:1] == '{':
|
290 |
+
try:
|
291 |
+
res[k] = json.loads(v)
|
292 |
+
except Exception:
|
293 |
+
pass
|
294 |
+
|
295 |
+
return res
|
296 |
+
|
297 |
+
|
298 |
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
299 |
+
_, extension = os.path.splitext(checkpoint_file)
|
300 |
+
if extension.lower() == ".safetensors":
|
301 |
+
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
302 |
+
|
303 |
+
if not shared.opts.disable_mmap_load_safetensors:
|
304 |
+
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
305 |
+
else:
|
306 |
+
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
307 |
+
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
308 |
+
else:
|
309 |
+
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
310 |
+
|
311 |
+
if print_global_state and "global_step" in pl_sd:
|
312 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
313 |
+
|
314 |
+
sd = get_state_dict_from_checkpoint(pl_sd)
|
315 |
+
return sd
|
316 |
+
|
317 |
+
|
318 |
+
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
319 |
+
sd_model_hash = checkpoint_info.calculate_shorthash()
|
320 |
+
timer.record("calculate hash")
|
321 |
+
|
322 |
+
if checkpoint_info in checkpoints_loaded:
|
323 |
+
# use checkpoint cache
|
324 |
+
print(f"Loading weights [{sd_model_hash}] from cache")
|
325 |
+
# move to end as latest
|
326 |
+
checkpoints_loaded.move_to_end(checkpoint_info)
|
327 |
+
return checkpoints_loaded[checkpoint_info]
|
328 |
+
|
329 |
+
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
330 |
+
res = read_state_dict(checkpoint_info.filename)
|
331 |
+
timer.record("load weights from disk")
|
332 |
+
|
333 |
+
return res
|
334 |
+
|
335 |
+
|
336 |
+
class SkipWritingToConfig:
|
337 |
+
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
338 |
+
|
339 |
+
skip = False
|
340 |
+
previous = None
|
341 |
+
|
342 |
+
def __enter__(self):
|
343 |
+
self.previous = SkipWritingToConfig.skip
|
344 |
+
SkipWritingToConfig.skip = True
|
345 |
+
return self
|
346 |
+
|
347 |
+
def __exit__(self, exc_type, exc_value, exc_traceback):
|
348 |
+
SkipWritingToConfig.skip = self.previous
|
349 |
+
|
350 |
+
|
351 |
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
352 |
+
sd_model_hash = checkpoint_info.calculate_shorthash()
|
353 |
+
timer.record("calculate hash")
|
354 |
+
|
355 |
+
if not SkipWritingToConfig.skip:
|
356 |
+
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
357 |
+
|
358 |
+
if state_dict is None:
|
359 |
+
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
360 |
+
|
361 |
+
model.is_sdxl = hasattr(model, 'conditioner')
|
362 |
+
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
363 |
+
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
364 |
+
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
365 |
+
if model.is_sdxl:
|
366 |
+
sd_models_xl.extend_sdxl(model)
|
367 |
+
|
368 |
+
if model.is_ssd:
|
369 |
+
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
|
370 |
+
|
371 |
+
if shared.opts.sd_checkpoint_cache > 0:
|
372 |
+
# cache newly loaded model
|
373 |
+
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
374 |
+
|
375 |
+
model.load_state_dict(state_dict, strict=False)
|
376 |
+
timer.record("apply weights to model")
|
377 |
+
|
378 |
+
del state_dict
|
379 |
+
|
380 |
+
if shared.cmd_opts.opt_channelslast:
|
381 |
+
model.to(memory_format=torch.channels_last)
|
382 |
+
timer.record("apply channels_last")
|
383 |
+
|
384 |
+
if shared.cmd_opts.no_half:
|
385 |
+
model.float()
|
386 |
+
devices.dtype_unet = torch.float32
|
387 |
+
timer.record("apply float()")
|
388 |
+
else:
|
389 |
+
vae = model.first_stage_model
|
390 |
+
depth_model = getattr(model, 'depth_model', None)
|
391 |
+
|
392 |
+
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
393 |
+
if shared.cmd_opts.no_half_vae:
|
394 |
+
model.first_stage_model = None
|
395 |
+
# with --upcast-sampling, don't convert the depth model weights to float16
|
396 |
+
if shared.cmd_opts.upcast_sampling and depth_model:
|
397 |
+
model.depth_model = None
|
398 |
+
|
399 |
+
model.half()
|
400 |
+
model.first_stage_model = vae
|
401 |
+
if depth_model:
|
402 |
+
model.depth_model = depth_model
|
403 |
+
|
404 |
+
devices.dtype_unet = torch.float16
|
405 |
+
timer.record("apply half()")
|
406 |
+
|
407 |
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
408 |
+
|
409 |
+
model.first_stage_model.to(devices.dtype_vae)
|
410 |
+
timer.record("apply dtype to VAE")
|
411 |
+
|
412 |
+
# clean up cache if limit is reached
|
413 |
+
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
414 |
+
checkpoints_loaded.popitem(last=False)
|
415 |
+
|
416 |
+
model.sd_model_hash = sd_model_hash
|
417 |
+
model.sd_model_checkpoint = checkpoint_info.filename
|
418 |
+
model.sd_checkpoint_info = checkpoint_info
|
419 |
+
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
420 |
+
|
421 |
+
if hasattr(model, 'logvar'):
|
422 |
+
model.logvar = model.logvar.to(devices.device) # fix for training
|
423 |
+
|
424 |
+
sd_vae.delete_base_vae()
|
425 |
+
sd_vae.clear_loaded_vae()
|
426 |
+
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
427 |
+
sd_vae.load_vae(model, vae_file, vae_source)
|
428 |
+
timer.record("load VAE")
|
429 |
+
|
430 |
+
|
431 |
+
def enable_midas_autodownload():
|
432 |
+
"""
|
433 |
+
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
434 |
+
|
435 |
+
When the 512-depth-ema model, and other future models like it, is loaded,
|
436 |
+
it calls midas.api.load_model to load the associated midas depth model.
|
437 |
+
This function applies a wrapper to download the model to the correct
|
438 |
+
location automatically.
|
439 |
+
"""
|
440 |
+
|
441 |
+
midas_path = os.path.join(paths.models_path, 'midas')
|
442 |
+
|
443 |
+
# stable-diffusion-stability-ai hard-codes the midas model path to
|
444 |
+
# a location that differs from where other scripts using this model look.
|
445 |
+
# HACK: Overriding the path here.
|
446 |
+
for k, v in midas.api.ISL_PATHS.items():
|
447 |
+
file_name = os.path.basename(v)
|
448 |
+
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
449 |
+
|
450 |
+
midas_urls = {
|
451 |
+
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
452 |
+
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
453 |
+
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
454 |
+
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
455 |
+
}
|
456 |
+
|
457 |
+
midas.api.load_model_inner = midas.api.load_model
|
458 |
+
|
459 |
+
def load_model_wrapper(model_type):
|
460 |
+
path = midas.api.ISL_PATHS[model_type]
|
461 |
+
if not os.path.exists(path):
|
462 |
+
if not os.path.exists(midas_path):
|
463 |
+
mkdir(midas_path)
|
464 |
+
|
465 |
+
print(f"Downloading midas model weights for {model_type} to {path}")
|
466 |
+
request.urlretrieve(midas_urls[model_type], path)
|
467 |
+
print(f"{model_type} downloaded")
|
468 |
+
|
469 |
+
return midas.api.load_model_inner(model_type)
|
470 |
+
|
471 |
+
midas.api.load_model = load_model_wrapper
|
472 |
+
|
473 |
+
|
474 |
+
def patch_given_betas():
|
475 |
+
import ldm.models.diffusion.ddpm
|
476 |
+
|
477 |
+
def patched_register_schedule(*args, **kwargs):
|
478 |
+
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
479 |
+
|
480 |
+
if isinstance(args[1], ListConfig):
|
481 |
+
args = (args[0], np.array(args[1]), *args[2:])
|
482 |
+
|
483 |
+
original_register_schedule(*args, **kwargs)
|
484 |
+
|
485 |
+
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
486 |
+
|
487 |
+
|
488 |
+
def repair_config(sd_config):
|
489 |
+
|
490 |
+
if not hasattr(sd_config.model.params, "use_ema"):
|
491 |
+
sd_config.model.params.use_ema = False
|
492 |
+
|
493 |
+
if hasattr(sd_config.model.params, 'unet_config'):
|
494 |
+
if shared.cmd_opts.no_half:
|
495 |
+
sd_config.model.params.unet_config.params.use_fp16 = False
|
496 |
+
elif shared.cmd_opts.upcast_sampling:
|
497 |
+
sd_config.model.params.unet_config.params.use_fp16 = True
|
498 |
+
|
499 |
+
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
500 |
+
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
501 |
+
|
502 |
+
# For UnCLIP-L, override the hardcoded karlo directory
|
503 |
+
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
|
504 |
+
karlo_path = os.path.join(paths.models_path, 'karlo')
|
505 |
+
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
506 |
+
|
507 |
+
|
508 |
+
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
509 |
+
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
510 |
+
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
511 |
+
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
512 |
+
|
513 |
+
|
514 |
+
class SdModelData:
|
515 |
+
def __init__(self):
|
516 |
+
self.sd_model = None
|
517 |
+
self.loaded_sd_models = []
|
518 |
+
self.was_loaded_at_least_once = False
|
519 |
+
self.lock = threading.Lock()
|
520 |
+
|
521 |
+
def get_sd_model(self):
|
522 |
+
if self.was_loaded_at_least_once:
|
523 |
+
return self.sd_model
|
524 |
+
|
525 |
+
if self.sd_model is None:
|
526 |
+
with self.lock:
|
527 |
+
if self.sd_model is not None or self.was_loaded_at_least_once:
|
528 |
+
return self.sd_model
|
529 |
+
|
530 |
+
try:
|
531 |
+
load_model()
|
532 |
+
|
533 |
+
except Exception as e:
|
534 |
+
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
535 |
+
print("", file=sys.stderr)
|
536 |
+
print("Stable diffusion model failed to load", file=sys.stderr)
|
537 |
+
self.sd_model = None
|
538 |
+
|
539 |
+
return self.sd_model
|
540 |
+
|
541 |
+
def set_sd_model(self, v, already_loaded=False):
|
542 |
+
self.sd_model = v
|
543 |
+
if already_loaded:
|
544 |
+
sd_vae.base_vae = getattr(v, "base_vae", None)
|
545 |
+
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
546 |
+
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
547 |
+
|
548 |
+
try:
|
549 |
+
self.loaded_sd_models.remove(v)
|
550 |
+
except ValueError:
|
551 |
+
pass
|
552 |
+
|
553 |
+
if v is not None:
|
554 |
+
self.loaded_sd_models.insert(0, v)
|
555 |
+
|
556 |
+
|
557 |
+
model_data = SdModelData()
|
558 |
+
|
559 |
+
|
560 |
+
def get_empty_cond(sd_model):
|
561 |
+
|
562 |
+
p = processing.StableDiffusionProcessingTxt2Img()
|
563 |
+
extra_networks.activate(p, {})
|
564 |
+
|
565 |
+
if hasattr(sd_model, 'conditioner'):
|
566 |
+
d = sd_model.get_learned_conditioning([""])
|
567 |
+
return d['crossattn']
|
568 |
+
else:
|
569 |
+
return sd_model.cond_stage_model([""])
|
570 |
+
|
571 |
+
|
572 |
+
def send_model_to_cpu(m):
|
573 |
+
if m.lowvram:
|
574 |
+
lowvram.send_everything_to_cpu()
|
575 |
+
else:
|
576 |
+
m.to(devices.cpu)
|
577 |
+
|
578 |
+
devices.torch_gc()
|
579 |
+
|
580 |
+
|
581 |
+
def model_target_device(m):
|
582 |
+
if lowvram.is_needed(m):
|
583 |
+
return devices.cpu
|
584 |
+
else:
|
585 |
+
return devices.device
|
586 |
+
|
587 |
+
|
588 |
+
def send_model_to_device(m):
|
589 |
+
lowvram.apply(m)
|
590 |
+
|
591 |
+
if not m.lowvram:
|
592 |
+
m.to(shared.device)
|
593 |
+
|
594 |
+
|
595 |
+
def send_model_to_trash(m):
|
596 |
+
m.to(device="meta")
|
597 |
+
devices.torch_gc()
|
598 |
+
|
599 |
+
|
600 |
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
601 |
+
from modules import sd_hijack
|
602 |
+
checkpoint_info = checkpoint_info or select_checkpoint()
|
603 |
+
|
604 |
+
timer = Timer()
|
605 |
+
|
606 |
+
if model_data.sd_model:
|
607 |
+
send_model_to_trash(model_data.sd_model)
|
608 |
+
model_data.sd_model = None
|
609 |
+
devices.torch_gc()
|
610 |
+
|
611 |
+
timer.record("unload existing model")
|
612 |
+
|
613 |
+
if already_loaded_state_dict is not None:
|
614 |
+
state_dict = already_loaded_state_dict
|
615 |
+
else:
|
616 |
+
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
617 |
+
|
618 |
+
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
619 |
+
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
620 |
+
|
621 |
+
timer.record("find config")
|
622 |
+
|
623 |
+
sd_config = OmegaConf.load(checkpoint_config)
|
624 |
+
repair_config(sd_config)
|
625 |
+
|
626 |
+
timer.record("load config")
|
627 |
+
|
628 |
+
print(f"Creating model from config: {checkpoint_config}")
|
629 |
+
|
630 |
+
sd_model = None
|
631 |
+
try:
|
632 |
+
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
633 |
+
with sd_disable_initialization.InitializeOnMeta():
|
634 |
+
sd_model = instantiate_from_config(sd_config.model)
|
635 |
+
|
636 |
+
except Exception as e:
|
637 |
+
errors.display(e, "creating model quickly", full_traceback=True)
|
638 |
+
|
639 |
+
if sd_model is None:
|
640 |
+
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
641 |
+
|
642 |
+
with sd_disable_initialization.InitializeOnMeta():
|
643 |
+
sd_model = instantiate_from_config(sd_config.model)
|
644 |
+
|
645 |
+
sd_model.used_config = checkpoint_config
|
646 |
+
|
647 |
+
timer.record("create model")
|
648 |
+
|
649 |
+
if shared.cmd_opts.no_half:
|
650 |
+
weight_dtype_conversion = None
|
651 |
+
else:
|
652 |
+
weight_dtype_conversion = {
|
653 |
+
'first_stage_model': None,
|
654 |
+
'': torch.float16,
|
655 |
+
}
|
656 |
+
|
657 |
+
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
658 |
+
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
659 |
+
timer.record("load weights from state dict")
|
660 |
+
|
661 |
+
send_model_to_device(sd_model)
|
662 |
+
timer.record("move model to device")
|
663 |
+
|
664 |
+
sd_hijack.model_hijack.hijack(sd_model)
|
665 |
+
|
666 |
+
timer.record("hijack")
|
667 |
+
|
668 |
+
sd_model.eval()
|
669 |
+
model_data.set_sd_model(sd_model)
|
670 |
+
model_data.was_loaded_at_least_once = True
|
671 |
+
|
672 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
673 |
+
|
674 |
+
timer.record("load textual inversion embeddings")
|
675 |
+
|
676 |
+
script_callbacks.model_loaded_callback(sd_model)
|
677 |
+
|
678 |
+
timer.record("scripts callbacks")
|
679 |
+
|
680 |
+
with devices.autocast(), torch.no_grad():
|
681 |
+
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
682 |
+
|
683 |
+
timer.record("calculate empty prompt")
|
684 |
+
|
685 |
+
print(f"Model loaded in {timer.summary()}.")
|
686 |
+
|
687 |
+
return sd_model
|
688 |
+
|
689 |
+
|
690 |
+
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
691 |
+
"""
|
692 |
+
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
693 |
+
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
694 |
+
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
695 |
+
If no such model exists, returns None.
|
696 |
+
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
697 |
+
"""
|
698 |
+
|
699 |
+
already_loaded = None
|
700 |
+
for i in reversed(range(len(model_data.loaded_sd_models))):
|
701 |
+
loaded_model = model_data.loaded_sd_models[i]
|
702 |
+
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
703 |
+
already_loaded = loaded_model
|
704 |
+
continue
|
705 |
+
|
706 |
+
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
707 |
+
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
708 |
+
model_data.loaded_sd_models.pop()
|
709 |
+
send_model_to_trash(loaded_model)
|
710 |
+
timer.record("send model to trash")
|
711 |
+
|
712 |
+
if shared.opts.sd_checkpoints_keep_in_cpu:
|
713 |
+
send_model_to_cpu(sd_model)
|
714 |
+
timer.record("send model to cpu")
|
715 |
+
|
716 |
+
if already_loaded is not None:
|
717 |
+
send_model_to_device(already_loaded)
|
718 |
+
timer.record("send model to device")
|
719 |
+
|
720 |
+
model_data.set_sd_model(already_loaded, already_loaded=True)
|
721 |
+
|
722 |
+
if not SkipWritingToConfig.skip:
|
723 |
+
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
|
724 |
+
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
|
725 |
+
|
726 |
+
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
727 |
+
sd_vae.reload_vae_weights(already_loaded)
|
728 |
+
return model_data.sd_model
|
729 |
+
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
730 |
+
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
731 |
+
|
732 |
+
model_data.sd_model = None
|
733 |
+
load_model(checkpoint_info)
|
734 |
+
return model_data.sd_model
|
735 |
+
elif len(model_data.loaded_sd_models) > 0:
|
736 |
+
sd_model = model_data.loaded_sd_models.pop()
|
737 |
+
model_data.sd_model = sd_model
|
738 |
+
|
739 |
+
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
|
740 |
+
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
|
741 |
+
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
|
742 |
+
|
743 |
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
744 |
+
return sd_model
|
745 |
+
else:
|
746 |
+
return None
|
747 |
+
|
748 |
+
|
749 |
+
def reload_model_weights(sd_model=None, info=None):
|
750 |
+
checkpoint_info = info or select_checkpoint()
|
751 |
+
|
752 |
+
timer = Timer()
|
753 |
+
|
754 |
+
if not sd_model:
|
755 |
+
sd_model = model_data.sd_model
|
756 |
+
|
757 |
+
if sd_model is None: # previous model load failed
|
758 |
+
current_checkpoint_info = None
|
759 |
+
else:
|
760 |
+
current_checkpoint_info = sd_model.sd_checkpoint_info
|
761 |
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
762 |
+
return sd_model
|
763 |
+
|
764 |
+
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
765 |
+
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
766 |
+
return sd_model
|
767 |
+
|
768 |
+
if sd_model is not None:
|
769 |
+
sd_unet.apply_unet("None")
|
770 |
+
send_model_to_cpu(sd_model)
|
771 |
+
sd_hijack.model_hijack.undo_hijack(sd_model)
|
772 |
+
|
773 |
+
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
774 |
+
|
775 |
+
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
776 |
+
|
777 |
+
timer.record("find config")
|
778 |
+
|
779 |
+
if sd_model is None or checkpoint_config != sd_model.used_config:
|
780 |
+
if sd_model is not None:
|
781 |
+
send_model_to_trash(sd_model)
|
782 |
+
|
783 |
+
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
784 |
+
return model_data.sd_model
|
785 |
+
|
786 |
+
try:
|
787 |
+
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
788 |
+
except Exception:
|
789 |
+
print("Failed to load checkpoint, restoring previous")
|
790 |
+
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
791 |
+
raise
|
792 |
+
finally:
|
793 |
+
sd_hijack.model_hijack.hijack(sd_model)
|
794 |
+
timer.record("hijack")
|
795 |
+
|
796 |
+
script_callbacks.model_loaded_callback(sd_model)
|
797 |
+
timer.record("script callbacks")
|
798 |
+
|
799 |
+
if not sd_model.lowvram:
|
800 |
+
sd_model.to(devices.device)
|
801 |
+
timer.record("move model to device")
|
802 |
+
|
803 |
+
print(f"Weights loaded in {timer.summary()}.")
|
804 |
+
|
805 |
+
model_data.set_sd_model(sd_model)
|
806 |
+
sd_unet.apply_unet()
|
807 |
+
|
808 |
+
return sd_model
|
809 |
+
|
810 |
+
|
811 |
+
def unload_model_weights(sd_model=None, info=None):
|
812 |
+
send_model_to_cpu(sd_model or shared.sd_model)
|
813 |
+
|
814 |
+
return sd_model
|
815 |
+
|
816 |
+
|
817 |
+
def apply_token_merging(sd_model, token_merging_ratio):
|
818 |
+
"""
|
819 |
+
Applies speed and memory optimizations from tomesd.
|
820 |
+
"""
|
821 |
+
|
822 |
+
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
|
823 |
+
|
824 |
+
if current_token_merging_ratio == token_merging_ratio:
|
825 |
+
return
|
826 |
+
|
827 |
+
if current_token_merging_ratio > 0:
|
828 |
+
tomesd.remove_patch(sd_model)
|
829 |
+
|
830 |
+
if token_merging_ratio > 0:
|
831 |
+
tomesd.apply_patch(
|
832 |
+
sd_model,
|
833 |
+
ratio=token_merging_ratio,
|
834 |
+
use_rand=False, # can cause issues with some samplers
|
835 |
+
merge_attn=True,
|
836 |
+
merge_crossattn=False,
|
837 |
+
merge_mlp=False
|
838 |
+
)
|
839 |
+
|
840 |
+
sd_model.applied_token_merged_ratio = token_merging_ratio
|