# Spaces:
# Runtime error
# Runtime error
import os | |
import re | |
import lora_patches | |
import functools | |
import network | |
import torch | |
from typing import Union | |
from modules import shared, sd_models, errors, scripts | |
from ldm_patched.modules.utils import load_torch_file | |
from ldm_patched.modules.sd import load_lora_for_models | |
def load_lora_state_dict(filename):
    """Read the LoRA checkpoint at *filename* and return its state dict.

    Loading is delegated to ``load_torch_file`` with ``safe_load=True``.
    """
    state_dict = load_torch_file(filename, safe_load=True)
    return state_dict
def convert_diffusers_name_to_compvis(key, is_sd2):
    """No-op placeholder: ignores ``key`` and ``is_sd2`` and returns None.

    NOTE(review): presumably kept only so existing references still resolve;
    LoRA application in this module goes through ``load_lora_for_models``.
    """
    return None
def assign_network_names_to_compvis_modules(sd_model):
    """No-op placeholder: does not touch ``sd_model`` and returns None."""
    return None
def load_network(name, network_on_disk):
    """No-op placeholder: loads nothing and returns None.

    NOTE(review): networks are applied by ``load_networks`` in this module;
    this stub presumably exists only for interface compatibility.
    """
    return None
def purge_networks_from_memory():
    """No-op placeholder: frees nothing and returns None."""
    return None
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Apply the named LoRA networks to the current model's Forge UNet/CLIP objects.

    Each name is resolved to a network on disk. If any name is unknown, the
    LoRA directories are rescanned once. Names that still cannot be resolved
    are reported and skipped. The UNet and CLIP objects are then reset to
    their unpatched originals, and each LoRA is applied in order with
    ``load_lora_for_models``. When the resolved ``[filename, unet_strength,
    clip_strength]`` list matches the one applied last time, nothing is
    reloaded.

    Args:
        names: network names or aliases, in application order.
        te_multipliers: per-network text-encoder (CLIP) strengths. A missing
            list (``None``) or a missing entry defaults to 1.0.
        unet_multipliers: per-network UNet strengths. A missing list
            (``None``) or a missing entry defaults to 1.0.
        dyn_dims: accepted for interface compatibility; not used here.
    """
    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    def _resolve(name):
        # Names whose lower-cased form is a forbidden (ambiguous) alias must
        # be looked up by exact network name rather than by alias.
        if name.lower() in forbidden_network_aliases:
            return available_networks.get(name, None)
        return available_network_aliases.get(name, None)

    # Materialise the multiplier iterables so they can be indexed per network;
    # the declared defaults are None, which zip() cannot iterate.
    unet_strengths = list(unet_multipliers) if unet_multipliers is not None else []
    clip_strengths = list(te_multipliers) if te_multipliers is not None else []

    def _strength(strengths, index):
        return strengths[index] if index < len(strengths) else 1.0

    networks_on_disk = [_resolve(name) for name in names]
    if any(x is None for x in networks_on_disk):
        # An unknown name may be a LoRA added since the last scan.
        list_available_networks()
        networks_on_disk = [_resolve(name) for name in names]

    compiled_lora_targets = []
    for index, (name, network_on_disk) in enumerate(zip(names, networks_on_disk)):
        if network_on_disk is None:
            # Previously this crashed with AttributeError on `.filename`.
            print(f"Couldn't find network with name {name}")
            continue
        compiled_lora_targets.append([
            network_on_disk.filename,
            _strength(unet_strengths, index),
            _strength(clip_strengths, index),
        ])

    compiled_lora_targets_hash = str(compiled_lora_targets)
    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    # Invalidate the cache key before mutating the model, and only record the
    # new key once every LoRA has been applied. Otherwise a failure part-way
    # would leave the model marked as already carrying this LoRA set.
    current_sd.current_lora_hash = None

    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip)

    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """No-op placeholder: restores nothing on the given module and returns None."""
    return None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """No-op placeholder: leaves the module's weights untouched and returns None.

    NOTE(review): LoRA weights are applied in ``load_networks`` via
    ``load_lora_for_models``; this stub presumably exists for compatibility.
    """
    return None
def network_forward(org_module, input, original_forward):
    """No-op placeholder: does not call ``original_forward`` and returns None.

    NOTE(review): if this were ever installed as a real forward hook, the
    module would output None. Presumably it is kept only so references
    resolve; confirm.
    """
    return None
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """No-op placeholder: resets no cached state and returns None."""
    return None
def network_Linear_forward(self, input):
    """No-op placeholder for a Linear forward hook; computes nothing, returns None."""
    return None
def network_Linear_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: ignores all arguments and returns None."""
    return None
def network_Conv2d_forward(self, input):
    """No-op placeholder for a Conv2d forward hook; computes nothing, returns None."""
    return None
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: ignores all arguments and returns None."""
    return None
def network_GroupNorm_forward(self, input):
    """No-op placeholder for a GroupNorm forward hook; computes nothing, returns None."""
    return None
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: ignores all arguments and returns None."""
    return None
def network_LayerNorm_forward(self, input):
    """No-op placeholder for a LayerNorm forward hook; computes nothing, returns None."""
    return None
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: ignores all arguments and returns None."""
    return None
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """No-op placeholder for a MultiheadAttention forward hook; returns None."""
    return None
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: ignores all arguments and returns None."""
    return None
def list_available_networks():
    """Rescan the LoRA directories and rebuild the module-level registries.

    Clears ``available_networks``, ``available_network_aliases``,
    ``forbidden_network_aliases`` and ``available_network_hash_lookup``. It
    then registers every ``.pt`` / ``.ckpt`` / ``.safetensors`` file found
    under ``cmd_opts.lora_dir`` and ``cmd_opts.lyco_dir_backcompat``. Each
    file is stored by its file stem in ``available_networks``, and by both
    stem and alias in ``available_network_aliases``. An alias that collides
    with an already-registered name or alias is added (lower-cased) to
    ``forbidden_network_aliases``, so lookups of it fall back to the exact
    network name. Files that raise ``OSError`` while being read are reported
    and skipped.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # Lookups test `name.lower()` against this dict, so every key must be
    # lower-case; a mixed-case key such as "Addams" could never match.
    forbidden_network_aliases.update({"none": 1, "addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    extensions = [".pt", ".ckpt", ".safetensors"]
    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=extensions))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=extensions))

    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry
# Matches "<name>(<hex hash>)" or "<name> (<hex hash>)". The greedy group also
# captures whitespace before the parenthesis, so callers should strip it.
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    """Turn legacy "AddNet" LoRA fields from pasted infotext into prompt tags.

    For every ``AddNet Model N`` field whose ``AddNet Module N`` equals
    ``"LoRA"``, a ``<lora:NAME:WEIGHT>`` tag is appended to
    ``params["Prompt"]``. WEIGHT is taken from ``AddNet Weight A N``
    (default ``"1.0"``). A trailing ``(hexhash)`` suffix, and any whitespace
    before it, is removed from the model name. Nothing happens when
    ``"AddNet Module 1"`` is already among the txt2img infotext fields.

    Args:
        infotext: the raw pasted infotext (not used here).
        params: parsed infotext fields; modified in place.
    """
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    model_prefix = "AddNet Model "
    added = []

    for k in params:
        if not k.startswith(model_prefix):
            continue

        num = k[len(model_prefix):]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get(model_prefix + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            # Strip the whitespace the greedy group keeps before "(hash)";
            # otherwise "name (abc123)" would produce the tag "<lora:name :...>".
            name = m.group(1).rstrip()

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        # Tolerate infotext that carried no prompt instead of raising KeyError.
        params["Prompt"] = params.get("Prompt", "") + "\n" + "".join(added)
# Module-level state.  The list_available_networks() call at the bottom runs
# at import time and fills the registries declared below.
originals: lora_patches.LoraPatches = None
extra_network_lora = None

# Networks currently loaded / cached.
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}

# Registries rebuilt by list_available_networks():
available_networks = {}             # file stem -> network.NetworkOnDisk
available_network_aliases = {}      # file stem or alias -> network.NetworkOnDisk
available_network_hash_lookup = {}  # cleared on each rescan; not filled by the code shown here
forbidden_network_aliases = {}      # ambiguous aliases: resolve these by exact name instead

# Populate the registries once when the module is imported.
list_available_networks()