import argparse
import os
from safetensors.torch import load_file, save_file
import toml
import re
from safetensors import safe_open
import math
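# Usage sketch (the script filename "adjust_lora.py" and the file paths below are
# illustrative placeholders, not part of the original file):
#
#   # 1. Dump every weight group in the LoRA to a TOML file (all weights start at 1.0)
#   python adjust_lora.py extract --lora_path my_lora.safetensors --output_path my_lora.toml
#
#   # 2. Edit the weights in the TOML, then write a rescaled copy of the LoRA
#   python adjust_lora.py adjust --lora_path my_lora.safetensors --toml_path my_lora.toml \
#       --output_path my_lora_adjusted.safetensors --multiplier 0.8 --remove_zero_weight_keys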
def parse_key(key):
    """Split a LoRA tensor key into (model_type, block_type, block_num, layer_key)."""
    # UNet keys of the form lora_unet_{input|output|up|down}_blocks_<n>[_<m>]_<layer>.{alpha|lora_down.weight|lora_up.weight}
    match = re.match(r"lora_unet_(input|output|up|down)_blocks_(\d+(?:_\d+)?)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1) + "_blocks", match.group(2), match.group(3)
    # mid_block keys carry an extra resnets/attentions component, which is folded into the block number
    match = re.match(r"lora_unet_(mid_block)_(resnets|attentions)_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1), f"{match.group(2)}_{match.group(3)}", match.group(4)
    match = re.match(r"lora_unet_(middle_block)_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1), match.group(2), match.group(3)
    # Text encoder keys; the first underscore-separated token serves as the block number
    match = re.match(r"lora_te\d+_text_model_encoder_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "text_encoder", "encoder_layers", match.group(1).split("_")[0], "_".join(match.group(1).split("_")[1:])
    return None, None, None, None
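# Illustrative example of what parse_key returns for a typical UNet key
# (the key name is made up, but it matches the first pattern above):
#   parse_key("lora_unet_output_blocks_2_1_transformer_blocks_0_attn1_to_q.lora_down.weight")
#   -> ("unet", "output_blocks", "2_1", "transformer_blocks_0_attn1_to_q")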
def extract_lora_hierarchy(lora_tensors, mode="extract"):
lora_hierarchy = {}
lora_key_groups = {"unet": {}, "text_encoder": {}} if mode == "adjust" else None
for key in lora_tensors:
if key.startswith("lora_unet_"):
model_type, block_type, block_num, layer_key = parse_key(key)
if model_type and block_type and layer_key:
parts = layer_key.split("_")
if "transformer_blocks" in layer_key:
grouped_key = "_".join(parts[:3] + [parts[3] if len(parts) > 5 else ""])
elif "attentions" in layer_key:
grouped_key = "_".join(parts[:3] + [parts[3] if len(parts) > 5 else ""])
elif "resnets" in layer_key:
grouped_key = "_".join(parts[:3])
else:
grouped_key = layer_key
if model_type not in lora_hierarchy:
lora_hierarchy[model_type] = {}
if block_type not in lora_hierarchy[model_type]:
lora_hierarchy[model_type][block_type] = {}
if block_num not in lora_hierarchy[model_type][block_type]:
lora_hierarchy[model_type][block_type][block_num] = {}
lora_hierarchy[model_type][block_type][block_num][grouped_key] = 1.0
if mode == "adjust":
group_key = f"..unet_{block_type}_{block_num}_{grouped_key}"
if group_key not in lora_key_groups["unet"]:
lora_key_groups["unet"][group_key] = []
lora_key_groups["unet"][group_key].append(key)
elif key.startswith("lora_te"):
match = re.match(r"(lora_te\d+)_text_model_encoder_layers_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
if match:
model_section = match.group(1)
block_type = "encoder"
block_num = match.group(2)
layer_key = match.group(3)
grouped_key = f"layers_{block_num}__{layer_key}"
if model_section not in lora_hierarchy:
lora_hierarchy[model_section] = {}
if block_type not in lora_hierarchy[model_section]:
lora_hierarchy[model_section][block_type] = {}
lora_hierarchy[model_section][block_type][grouped_key] = 1.0
if mode == "adjust":
group_key = f"..{model_section}_{block_num}_{layer_key}"
lora_key_groups["text_encoder"][group_key] = [key]
return lora_hierarchy if mode == "extract" else lora_key_groups
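# In "extract" mode the hierarchy above is later dumped via write_toml(); for the example key
# shown under parse_key(), plus a text-encoder key such as
# "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight", the TOML would look
# roughly like:
#
#   [unet.output_blocks.2_1]
#   transformer_blocks_0_attn1 = 1.0
#
#   [lora_te1.encoder]
#   layers_0__self_attn_q_proj = 1.0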
def adjust_lora_weights(lora_path, toml_path, output_path, multiplier=1.0, remove_zero_weight_keys=True):
    """Rescale LoRA tensors according to per-block weights from a TOML config and save the result."""
    try:
        lora_tensors = load_file(lora_path)
        with safe_open(lora_path, framework="pt") as f:
            metadata = f.metadata()
    except Exception as e:
        raise Exception(f"Error loading LoRA model: {e}")
    try:
        with open(toml_path, "r") as f:
            lora_config = toml.load(f)
    except Exception as e:
        raise Exception(f"Error loading TOML file: {e}")
    lora_key_groups = extract_lora_hierarchy(lora_tensors, mode="adjust")
    adjusted_tensors = {}
    for model_section, model_config in lora_config.items():
        if model_section.startswith("lora_te"):
            for block_type, layers in model_config.items():
                for layer_key, weight in layers.items():
                    # layer_key looks like "layers_<n>__<layer_name>"; recover the pieces used in the group key
                    block_num, layer_name = layer_key.replace("layers_", "").split("__")
                    group_key = f"..{model_section}_{block_num}_{layer_name}"
                    if group_key in lora_key_groups["text_encoder"]:
                        final_weight = weight * multiplier
                        if not remove_zero_weight_keys or final_weight != 0.0:
                            for target_key in lora_key_groups["text_encoder"][group_key]:
                                if target_key.endswith(".alpha"):
                                    # alpha is kept as-is; the scaling is applied to the factors instead
                                    adjusted_tensors[target_key] = lora_tensors[target_key]
                                else:
                                    # Scaling both lora_down and lora_up by sqrt(w) scales their product by w.
                                    # Negative weights are not supported by this scheme.
                                    adjusted_tensors[target_key] = lora_tensors[target_key] * math.sqrt(final_weight)
        else:  # unet
            for block_type, block_nums in model_config.items():
                for block_num, layer_keys in block_nums.items():
                    for grouped_key, weight in layer_keys.items():
                        group_key = f"..unet_{block_type}_{block_num}_{grouped_key}"
                        if group_key in lora_key_groups["unet"]:
                            final_weight = weight * multiplier
                            if not remove_zero_weight_keys or final_weight != 0.0:
                                for target_key in lora_key_groups["unet"][group_key]:
                                    if target_key.endswith(".alpha"):
                                        adjusted_tensors[target_key] = lora_tensors[target_key]
                                    else:
                                        adjusted_tensors[target_key] = lora_tensors[target_key] * math.sqrt(final_weight)
    try:
        save_file(adjusted_tensors, output_path, metadata)
    except Exception as e:
        raise Exception(f"Error saving adjusted model: {e}")
def write_toml(lora_hierarchy, output_path):
    try:
        with open(output_path, "w") as f:
            toml.dump(lora_hierarchy, f)
    except Exception as e:
        raise Exception(f"Error writing TOML file: {e}")
def main():
    parser = argparse.ArgumentParser(description="Extract or adjust LoRA weights based on a TOML config.")
    subparsers = parser.add_subparsers(dest="mode", help="Choose mode: 'extract' or 'adjust'")
    # Extract mode
    parser_extract = subparsers.add_parser("extract", help="Extract LoRA hierarchy to a TOML file")
    parser_extract.add_argument("--lora_path", required=True, help="Path to the LoRA safetensors file")
    parser_extract.add_argument("--output_path", required=True, help="Path to the output TOML file")
    # Adjust mode
    parser_adjust = subparsers.add_parser("adjust", help="Adjust LoRA weights based on a TOML config.")
    parser_adjust.add_argument("--lora_path", required=True, help="Path to the LoRA safetensors file")
    parser_adjust.add_argument("--toml_path", required=True, help="Path to the TOML config file")
    parser_adjust.add_argument("--output_path", required=True, help="Path to the output safetensors file")
    parser_adjust.add_argument("--multiplier", type=float, default=1.0, help="Global multiplier for the LoRA weights")
    parser_adjust.add_argument("--remove_zero_weight_keys", action="store_true",
                               help="Remove keys with a resulting weight of 0. Useful for reducing file size.")
    args = parser.parse_args()
    try:
        if args.mode == "extract":
            lora_tensors = load_file(args.lora_path)
            lora_hierarchy = extract_lora_hierarchy(lora_tensors)
            write_toml(lora_hierarchy, args.output_path)
            print(f"Successfully extracted LoRA hierarchy to {args.output_path}")
        elif args.mode == "adjust":
            adjust_lora_weights(args.lora_path, args.toml_path, args.output_path, args.multiplier, args.remove_zero_weight_keys)
            print(f"Successfully adjusted LoRA weights and saved to {args.output_path}")
        else:
            parser.print_help()
    except Exception as e:
        print(f"An error occurred: {e}")
if __name__ == "__main__":
    main()