File size: 9,524 Bytes
7f5a1b6
 
 
 
 
 
a77578c
7f5a1b6
 
c1a5584
 
 
7f5a1b6
c1a5584
 
 
7f5a1b6
c1a5584
 
 
7f5a1b6
c1a5584
 
 
7f5a1b6
c1a5584
7f5a1b6
 
c1a5584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f5a1b6
 
 
c1a5584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f5a1b6
 
 
c1a5584
 
 
 
 
7f5a1b6
 
 
c1a5584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f5a1b6
 
c1a5584
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import argparse
import math
import os
import re
import sys

import toml
from safetensors import safe_open
from safetensors.torch import load_file, save_file


def parse_key(key):
    """Split a LoRA tensor name into ``(model_type, block_type, block_num, layer_key)``.

    Every recognised name must end in ``.alpha``, ``.lora_down.weight`` or
    ``.lora_up.weight``. Recognised prefixes:

    * ``lora_unet_{input,output,up,down}_blocks_<n>[_<m>]_<layer>``
      -> ``("unet", "<kind>_blocks", "<n>[_<m>]", "<layer>")``
    * ``lora_unet_mid_block_{resnets,attentions}_<n>_<layer>``
      -> ``("unet", "mid_block", "<resnets|attentions>_<n>", "<layer>")``
    * ``lora_unet_middle_block_<n>_<layer>``
      -> ``("unet", "middle_block", "<n>", "<layer>")``
    * ``lora_te<k>_text_model_encoder_layers_<n>_<layer>``
      -> ``("text_encoder", "encoder_layers", "<n>", "<layer>")``
    * any other ``lora_te<k>_text_model_encoder_<rest>`` name is split at the
      first underscore of ``<rest>``.

    Returns:
        The 4-tuple above, or ``(None, None, None, None)`` if no pattern matches.
    """
    match = re.match(r"lora_unet_(input|output|up|down)_blocks_(\d+(?:_\d+)?)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1) + "_blocks", match.group(2), match.group(3)

    match = re.match(r"lora_unet_(mid_block)_(resnets|attentions)_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1), f"{match.group(2)}_{match.group(3)}", match.group(4)

    match = re.match(r"lora_unet_(middle_block)_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "unet", match.group(1), match.group(2), match.group(3)

    # Encoder layers: the block number is the <n> in "layers_<n>", the same
    # numbering extract_lora_hierarchy uses for text-encoder groups.
    match = re.match(r"lora_te\d+_text_model_encoder_layers_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        return "text_encoder", "encoder_layers", match.group(1), match.group(2)

    # Any other encoder tensor: split the remainder at its first underscore.
    match = re.match(r"lora_te\d+_text_model_encoder_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
    if match:
        head, _, tail = match.group(1).partition("_")
        return "text_encoder", "encoder_layers", head, tail

    return None, None, None, None

def extract_lora_hierarchy(lora_tensors, mode="extract"):
    """Group LoRA tensor names into the block hierarchy used by the TOML config.

    Only the keys of ``lora_tensors`` are inspected. Names that start with
    neither ``lora_unet_`` nor ``lora_te``, or that do not match the expected
    patterns, are ignored.

    Args:
        lora_tensors: Mapping (or iterable) of LoRA tensor names.
        mode: ``"extract"`` or ``"adjust"``. See Returns.

    Returns:
        * ``mode == "extract"``: a nested weight table with every weight set to 1.0::

              {"unet": {block_type: {block_num: {grouped_key: 1.0}}},
               "lora_te<k>": {"encoder": {"layers_<n>__<layer>": 1.0}}}

        * ``mode == "adjust"``: the tensor names belonging to each group::

              {"unet": {"..unet_<block_type>_<block_num>_<grouped_key>": [names]},
               "text_encoder": {"..lora_te<k>_<n>_<layer>": [names]}}

        * any other mode: ``None``.
    """
    lora_hierarchy = {}
    lora_key_groups = {"unet": {}, "text_encoder": {}} if mode == "adjust" else None

    for key in lora_tensors:
        if key.startswith("lora_unet_"):
            model_type, block_type, block_num, layer_key = parse_key(key)
            if not (model_type and block_type and layer_key):
                continue  # unrecognised UNet tensor name

            parts = layer_key.split("_")
            if "transformer_blocks" in layer_key or "attentions" in layer_key:
                # NOTE: this yields a trailing "_" when the layer has <= 5 parts.
                # It is kept as-is because the group names must match TOML files
                # already produced by the "extract" mode.
                grouped_key = "_".join(parts[:3] + [parts[3] if len(parts) > 5 else ""])
            elif "resnets" in layer_key:
                grouped_key = "_".join(parts[:3])
            else:
                grouped_key = layer_key

            blocks = lora_hierarchy.setdefault(model_type, {}).setdefault(block_type, {})
            blocks.setdefault(block_num, {})[grouped_key] = 1.0

            if mode == "adjust":
                group_key = f"..unet_{block_type}_{block_num}_{grouped_key}"
                lora_key_groups["unet"].setdefault(group_key, []).append(key)

        elif key.startswith("lora_te"):
            match = re.match(r"(lora_te\d+)_text_model_encoder_layers_(\d+)_(.+)\.(?:alpha|lora_(?:down|up)\.weight)", key)
            if not match:
                continue
            model_section, block_num, layer_key = match.groups()
            block_type = "encoder"

            grouped_key = f"layers_{block_num}__{layer_key}"
            lora_hierarchy.setdefault(model_section, {}).setdefault(block_type, {})[grouped_key] = 1.0

            if mode == "adjust":
                group_key = f"..{model_section}_{block_num}_{layer_key}"
                # A layer has several tensors (.alpha, .lora_down.weight,
                # .lora_up.weight); all of them must stay in the group.
                lora_key_groups["text_encoder"].setdefault(group_key, []).append(key)

    return lora_hierarchy if mode == "extract" else lora_key_groups


def _iter_weighted_groups(lora_config, lora_key_groups):
    """Yield ``(tensor_names, weight)`` for every TOML entry whose group exists in the LoRA."""
    for model_section, model_config in lora_config.items():
        if model_section.startswith("lora_te"):
            groups = lora_key_groups["text_encoder"]
            for layers in model_config.values():
                for layer_key, weight in layers.items():
                    # TOML keys look like "layers_<n>__<layer>"; a key without
                    # "__" fails to unpack and raises ValueError.
                    block_num, layer_name = layer_key.replace("layers_", "").split("__")
                    target_keys = groups.get(f"..{model_section}_{block_num}_{layer_name}")
                    if target_keys is not None:
                        yield target_keys, weight
        else:  # every other top-level section is treated as the UNet
            groups = lora_key_groups["unet"]
            for block_type, block_nums in model_config.items():
                for block_num, layer_keys in block_nums.items():
                    for grouped_key, weight in layer_keys.items():
                        target_keys = groups.get(f"..unet_{block_type}_{block_num}_{grouped_key}")
                        if target_keys is not None:
                            yield target_keys, weight


def _scale_lora_group(lora_tensors, target_keys, final_weight, adjusted_tensors):
    """Copy one group's tensors into ``adjusted_tensors``, scaling the up @ down product by ``final_weight``.

    ``.alpha`` tensors are copied unchanged. The down and up projections are
    each multiplied by ``sqrt(|final_weight|)``, and the sign of a negative
    weight goes on the up projection only, so their product is scaled by
    exactly ``final_weight``. Applying ``math.sqrt`` to a negative weight
    directly would raise ``ValueError``.
    """
    magnitude = math.sqrt(abs(final_weight))
    up_scale = -magnitude if final_weight < 0 else magnitude
    for target_key in target_keys:
        tensor = lora_tensors[target_key]
        if target_key.endswith(".alpha"):
            adjusted_tensors[target_key] = tensor
        elif target_key.endswith(".lora_up.weight"):
            adjusted_tensors[target_key] = tensor * up_scale
        else:  # .lora_down.weight
            adjusted_tensors[target_key] = tensor * magnitude


def adjust_lora_weights(lora_path, toml_path, output_path, multiplier=1.0, remove_zero_weight_keys=True):
    """Write a copy of a LoRA whose per-group weights come from a TOML config.

    Only tensors in a group that the TOML mentions are written to the output;
    every other tensor in the source LoRA is dropped. The source file's
    safetensors metadata is copied to the output unchanged.

    Args:
        lora_path: Source ``.safetensors`` LoRA file.
        toml_path: TOML config in the layout written by the ``extract`` mode.
        output_path: Destination ``.safetensors`` file.
        multiplier: Global factor applied on top of every per-group weight.
        remove_zero_weight_keys: If True, groups whose final weight is 0 are
            left out of the output entirely. If False, they are written with
            zeroed projections.

    Raises:
        Exception: If the LoRA or TOML cannot be loaded, or the output cannot
            be saved. The underlying error is chained as ``__cause__``.
        ValueError: If a text-encoder TOML key is not of the form
            ``layers_<n>__<layer>``.
    """
    try:
        lora_tensors = load_file(lora_path)
        with safe_open(lora_path, framework="pt") as f:
            metadata = f.metadata()
    except Exception as e:
        raise Exception(f"Error loading LoRA model: {e}") from e

    try:
        with open(toml_path, "r", encoding="utf-8") as f:
            lora_config = toml.load(f)
    except Exception as e:
        raise Exception(f"Error loading TOML file: {e}") from e

    lora_key_groups = extract_lora_hierarchy(lora_tensors, mode="adjust")
    adjusted_tensors = {}

    for target_keys, weight in _iter_weighted_groups(lora_config, lora_key_groups):
        final_weight = weight * multiplier
        if remove_zero_weight_keys and final_weight == 0.0:
            continue
        _scale_lora_group(lora_tensors, target_keys, final_weight, adjusted_tensors)

    try:
        save_file(adjusted_tensors, output_path, metadata)
    except Exception as e:
        raise Exception(f"Error saving adjusted model: {e}") from e


def write_toml(lora_hierarchy, output_path):
    """Serialize ``lora_hierarchy`` to a UTF-8 TOML file at ``output_path``.

    Args:
        lora_hierarchy: Nested dict to dump, e.g. the output of
            ``extract_lora_hierarchy``.
        output_path: Destination file path. An existing file is overwritten.

    Raises:
        Exception: If the file cannot be opened or written. The underlying
            error is chained as ``__cause__`` so the traceback keeps it.
    """
    try:
        # TOML documents are defined to be UTF-8, so do not rely on the locale default.
        with open(output_path, "w", encoding="utf-8") as f:
            toml.dump(lora_hierarchy, f)
    except Exception as e:
        raise Exception(f"Error writing TOML file: {e}") from e


def main():
    """Parse the command line and run the ``extract`` or ``adjust`` sub-command.

    Without a sub-command, the help text is printed. Any exception raised by a
    sub-command is reported on stderr and the process exits with status 1, so
    shell scripts and CI jobs can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Extract or adjust LoRA weights based on a TOML config.")
    subparsers = parser.add_subparsers(dest="mode", help="Choose mode: 'extract' or 'adjust'")

    # Extract mode
    parser_extract = subparsers.add_parser("extract", help="Extract LoRA hierarchy to a TOML file")
    parser_extract.add_argument("--lora_path", required=True, help="Path to the LoRA safetensors file")
    parser_extract.add_argument("--output_path", required=True, help="Path to the output TOML file")

    # Adjust mode
    parser_adjust = subparsers.add_parser("adjust", help="Adjust LoRA weights based on a TOML config.")
    parser_adjust.add_argument("--lora_path", required=True, help="Path to the LoRA safetensors file")
    parser_adjust.add_argument("--toml_path", required=True, help="Path to the TOML config file")
    parser_adjust.add_argument("--output_path", required=True, help="Path to the output safetensors file")
    parser_adjust.add_argument("--multiplier", type=float, default=1.0, help="Global multiplier for the LoRA weights")
    parser_adjust.add_argument("--remove_zero_weight_keys", action="store_true",
                            help="Remove keys with resulting weight of 0. Useful for reducing file size.")

    args = parser.parse_args()

    try:
        if args.mode == "extract":
            lora_tensors = load_file(args.lora_path)
            lora_hierarchy = extract_lora_hierarchy(lora_tensors)
            write_toml(lora_hierarchy, args.output_path)
            print(f"Successfully extracted LoRA hierarchy to {args.output_path}")

        elif args.mode == "adjust":
            adjust_lora_weights(args.lora_path, args.toml_path, args.output_path, args.multiplier, args.remove_zero_weight_keys)
            print(f"Successfully adjusted LoRA weights and saved to {args.output_path}")

        else:
            parser.print_help()

    except Exception as e:
        # Report on stderr and exit non-zero so callers can detect the failure.
        print(f"An error occurred: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()