Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import os | |
import re | |
import torch | |
import network | |
import functools | |
from backend.args import dynamic_args | |
from modules import shared, sd_models, errors, scripts | |
from backend.utils import load_torch_file | |
from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora | |
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
    """Patch clones of the UNet and CLIP objects with one LoRA state dict.

    The LoRA entries are split with ``load_lora`` into a UNet part and a CLIP
    part; whatever remains afterwards is treated as unmatched. More than 12
    unmatched keys aborts the load and the inputs are returned unchanged.

    Args:
        model: Object exposing ``.model``, ``.clone()`` and ``.add_patches()``,
            or None to skip the UNet side.
        clip: Object exposing ``.cond_stage_model``, ``.clone()`` and
            ``.add_patches()``, or None to skip the CLIP side.
        lora: LoRA state dict handed to ``load_lora``.
        strength_model: Forwarded as ``strength_patch`` for the UNet patches.
        strength_clip: Forwarded as ``strength_patch`` for the CLIP patches.
        filename: Label used in log lines and forwarded to ``add_patches``.
        online_mode: Forwarded unchanged to ``add_patches``.

    Returns:
        A ``(model, clip)`` tuple; each element is either the patched clone or
        the object that was passed in.
    """
    has_model = model is not None
    has_clip = clip is not None

    model_flag = type(model.model).__name__ if has_model else 'default'
    unet_keys = model_lora_keys_unet(model.model) if has_model else {}
    clip_keys = model_lora_keys_clip(clip.cond_stage_model) if has_clip else {}

    # Peel off the UNet entries first, then the CLIP entries; the rest is unmatched.
    lora_unet, lora_unmatch = load_lora(lora, unet_keys)
    lora_clip, lora_unmatch = load_lora(lora_unmatch, clip_keys)

    unmatched_count = len(lora_unmatch)
    if unmatched_count > 12:
        print(f'[LORA] LoRA version mismatch for {model_flag}: {filename}')
        return model, clip
    if unmatched_count > 0:
        print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')

    # Patches are applied to clones so the objects passed in are never modified.
    patched_model = model.clone() if has_model else None
    patched_clip = clip.clone() if has_clip else None

    if patched_model is not None and len(lora_unet) > 0:
        loaded_keys = patched_model.add_patches(filename=filename, patches=lora_unet, strength_patch=strength_model, online_mode=online_mode)
        skipped_keys = [key for key in lora_unet if key not in loaded_keys]
        if len(skipped_keys) <= 12:
            print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
            # NOTE(review): the clone is adopted only on this success path - confirm against upstream.
            model = patched_model
        else:
            print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')

    if patched_clip is not None and len(lora_clip) > 0:
        loaded_keys = patched_clip.add_patches(filename=filename, patches=lora_clip, strength_patch=strength_clip, online_mode=online_mode)
        skipped_keys = [key for key in lora_clip if key not in loaded_keys]
        if len(skipped_keys) <= 12:
            print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
            # NOTE(review): the clone is adopted only on this success path - confirm against upstream.
            clip = patched_clip
        else:
            print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')

    return model, clip
def load_lora_state_dict(filename):
    """Read the checkpoint at ``filename`` through ``load_torch_file`` with ``safe_load=True``."""
    state_dict = load_torch_file(filename, safe_load=True)
    return state_dict
def load_network(name, network_on_disk):
    """Build a ``network.Network`` for ``name`` and stamp it with the file's mtime.

    Args:
        name: Name passed through to ``network.Network``.
        network_on_disk: On-disk entry; its ``filename`` attribute is read.

    Returns:
        The new ``network.Network`` with ``mtime`` set from ``os.path.getmtime``.
    """
    loaded = network.Network(name, network_on_disk)
    modified_at = os.path.getmtime(network_on_disk.filename)
    loaded.mtime = modified_at
    return loaded
def _lookup_network_on_disk(name):
    """Return the registered on-disk entry for ``name``, or None when unknown."""
    if name.lower() in forbidden_network_aliases:
        # Forbidden aliases may only be resolved through the primary registry.
        return available_networks.get(name, None)
    return available_network_aliases.get(name, None)


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Resolve the requested LoRA names and apply them to the current model.

    Names that are not registered trigger a targeted rescan, then a full
    rescan if any are still missing. Names that remain unresolved are reported
    and skipped. The set of (file, strengths, online mode) targets is
    stringified and compared with ``current_sd.current_lora_hash`` so that an
    unchanged request does not re-patch the model.

    Args:
        names: Sequence of network names to load.
        te_multipliers: Per-network CLIP strengths, aligned with ``names``.
            ``None`` means 1.0 for every network.
        unet_multipliers: Per-network UNet strengths, aligned with ``names``.
            ``None`` means 1.0 for every network.
        dyn_dims: Accepted for interface compatibility; not used here.

    Returns:
        None. ``loaded_networks`` and ``current_sd.forge_objects`` are updated
        as side effects.
    """
    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    loaded_networks.clear()

    unavailable_networks = [
        name for name in names
        if (name.lower() in forbidden_network_aliases and available_networks.get(name) is None)
        or available_network_aliases.get(name) is None
    ]
    if unavailable_networks:
        update_available_networks_by_names(unavailable_networks)

    networks_on_disk = [_lookup_network_on_disk(name) for name in names]
    if any(x is None for x in networks_on_disk):
        # Targeted refresh was not enough; rebuild the registries from scratch.
        list_available_networks()
        networks_on_disk = [_lookup_network_on_disk(name) for name in names]

    for network_on_disk, name in zip(networks_on_disk, names):
        if network_on_disk is None:
            # Still unknown after a full rescan: report and skip it rather than
            # failing on attribute access against None.
            print(f'[LORA] Could not find network with name {name}; skipping')
            continue
        try:
            net = load_network(name, network_on_disk)
        except Exception as e:
            errors.display(e, f"loading network {network_on_disk.filename}")
            continue
        net.mentioned_name = name
        network_on_disk.read_hash()
        loaded_networks.append(net)

    online_mode = dynamic_args.get('online_lora', False)
    if current_sd.forge_objects.unet.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
        online_mode = False

    # The signature allows omitting the multipliers; zip() cannot iterate None.
    if unet_multipliers is None:
        unet_multipliers = [1.0] * len(names)
    if te_multipliers is None:
        te_multipliers = [1.0] * len(names)

    compiled_lora_targets = [
        [network_on_disk.filename, strength_model, strength_clip, online_mode]
        for network_on_disk, strength_model, strength_clip in zip(networks_on_disk, unet_multipliers, te_multipliers)
        if network_on_disk is not None
    ]
    compiled_lora_targets_hash = str(compiled_lora_targets)

    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        # Same targets as the previous call; the existing patches are kept.
        return

    current_sd.current_lora_hash = compiled_lora_targets_hash
    # Start again from the unpatched objects before applying the new targets.
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip, target_online_mode in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
            filename=filename, online_mode=target_online_mode)

    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return
def process_network_files(names: list[str] | None = None):
    """Register network files found under ``shared.cmd_opts.lora_dir``.

    Each accepted file is stored in ``available_networks`` under its base name
    and in ``available_network_aliases`` under both its base name and its
    alias. An alias that was already registered is also recorded (lowercased)
    in ``forbidden_network_aliases``.

    Args:
        names: When given and non-empty, only files whose base name (without
            extension) appears in this list are processed.
    """
    lora_files = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

    for filename in lora_files:
        if os.path.isdir(filename):
            continue

        name, _extension = os.path.splitext(os.path.basename(filename))

        # A non-empty names list acts as a filter on the base name.
        if names and name not in names:
            continue

        try:
            on_disk = network.NetworkOnDisk(name, filename)
        except OSError:  # covers FileNotFoundError, PermissionError and similar
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = on_disk

        # The alias check must run before this entry's own alias is registered below.
        alias_already_taken = on_disk.alias in available_network_aliases
        if alias_already_taken:
            forbidden_network_aliases[on_disk.alias.lower()] = 1

        available_network_aliases[name] = on_disk
        available_network_aliases[on_disk.alias] = on_disk
def update_available_networks_by_names(names: list[str]):
    """Rescan the LoRA directory, registering only the files named in ``names``."""
    process_network_files(names=names)
def list_available_networks():
    """Empty every network registry and rebuild them from the LoRA directory."""
    registries = (
        available_networks,
        available_network_aliases,
        forbidden_network_aliases,
        available_network_hash_lookup,
    )
    for registry in registries:
        registry.clear()

    # Aliases that are always forbidden.
    # NOTE(review): "Addams" is mixed-case while the lookups visible in this file
    # lowercase the name first - confirm whether this key can ever match.
    forbidden_network_aliases["none"] = 1
    forbidden_network_aliases["Addams"] = 1

    # Make sure the directory exists before walking it.
    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
    process_network_files()
# Matches "<name>(<hex digits>)"; group 1 is the name part before the parenthesised hex.
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    """Turn pasted "AddNet" LoRA fields into ``<lora:name:weight>`` prompt tags.

    For every "AddNet Model N" key whose matching "AddNet Module N" value is
    "LoRA", a tag is built from the model name (with a trailing "(hex)" part
    removed when present) and the "AddNet Weight A N" value, defaulting to
    "1.0". The tags are appended to ``params["Prompt"]`` on a new line.

    Args:
        infotext: Not used by this function.
        params: Mapping of pasted fields; modified in place.
    """
    registered_fields = [x[1] for x in scripts.scripts_txt2img.infotext_fields]
    if "AddNet Module 1" in registered_fields:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    model_prefix = "AddNet Model "
    added = []

    for key in params:
        if not key.startswith(model_prefix):
            continue

        # Everything after the prefix is the slot number shared by the sibling keys.
        num = key[len(model_prefix):]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get(model_prefix + num)
        if name is None:
            continue

        matched = re_network_name.match(name)
        if matched is not None:
            name = matched.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")
        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)
# Module-level state shared by the functions in this file.
extra_network_lora = None

# Base name -> on-disk entry, filled by process_network_files().
available_networks: dict = {}
# Base name or alias -> on-disk entry, filled by process_network_files().
available_network_aliases: dict = {}
# Networks registered by the most recent load_networks() call.
loaded_networks: list = []
loaded_bundle_embeddings: dict = {}
networks_in_memory: dict = {}
# Cleared by list_available_networks(); not otherwise written in this file.
available_network_hash_lookup: dict = {}
# Lowercased aliases that must not be resolved through the alias registry.
forbidden_network_aliases: dict = {}

# Populate the registries once at import time.
list_available_networks()