# Converts PixArt weights in the HF Diffusers format back to the original key layout.
#  Based on the original ComfyUI function and
#  https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py
import torch

conversion_map_ms = [  # extra resolution / aspect ratio embedders, only present in multi_scale_train (MS) checkpoints
    # Resolution
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    # Aspect ratio
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
]


def get_depth(state_dict):
    # Number of transformer blocks, counted via one marker key per block
    return sum(key.endswith('.attn1.to_k.bias') for key in state_dict.keys())


def get_lora_depth(state_dict):
    # Same, but for LoRA state dicts; count both PEFT (lora_A) and
    # kohya/OneTrainer (lora_down) style marker keys
    return sum(
        key.endswith('.attn1.to_k.lora_A.weight') or key.endswith('attn1_to_k.lora_down.weight')
        for key in state_dict.keys()
    )
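
# Toy illustration (not from a real checkpoint): both helpers just count one marker key per
# block, e.g. get_depth({"transformer_blocks.0.attn1.to_k.bias": 0,
#                        "transformer_blocks.1.attn1.to_k.bias": 0}) == 2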


def get_conversion_map(state_dict):
    conversion_map = [  # main state dict conversion map: (original PixArt key, HF Diffusers key)
        # Patch embeddings
        ("x_embedder.proj.weight", "pos_embed.proj.weight"),
        ("x_embedder.proj.bias", "pos_embed.proj.bias"),
        # Caption projection
        ("y_embedder.y_embedding", "caption_projection.y_embedding"),
        ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
        ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
        ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
        ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
        # Timestep embedder (AdaLN-single)
        ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
        # Shared AdaLN modulation (t_block)
        ("t_block.1.weight", "adaln_single.linear.weight"),
        ("t_block.1.bias", "adaln_single.linear.bias"),
        # Final block
        ("final_layer.linear.weight", "proj_out.weight"),
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.scale_shift_table", "scale_shift_table"),
    ]

    # Add actual transformer blocks
    for depth in range(get_depth(state_dict)):
        # Transformer blocks
        conversion_map += [
            (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"),
            # Projection
            (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"),
            (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"),
            # Feed-forward
            (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"),
            (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"),
            (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"),
            (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"),
            # Cross-attention (proj)
            (f"blocks.{depth}.cross_attn.proj.weight", f"transformer_blocks.{depth}.attn2.to_out.0.weight"),
            (f"blocks.{depth}.cross_attn.proj.bias", f"transformer_blocks.{depth}.attn2.to_out.0.bias"),
        ]
    return conversion_map


def find_prefix(state_dict, target_key):
    # Return whatever precedes target_key in the first key that ends with it,
    # e.g. a "transformer." prefix left over from the training framework.
    prefix = ""
    for k in state_dict.keys():
        if k.endswith(target_key):
            prefix = k.split(target_key)[0]
            break
    return prefix


def convert_state_dict(state_dict):
    # MS checkpoints also carry the resolution / aspect ratio embedders
    if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict.keys():
        cmap = get_conversion_map(state_dict) + conversion_map_ms
    else:
        cmap = get_conversion_map(state_dict)

    # cmap entries are (original PixArt key, diffusers key); copy over every direct match
    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = list(v for k, v in cmap if v in state_dict.keys())

    for depth in range(get_depth(state_dict)):
        for wb in ["weight", "bias"]:
            # Self-attention: fuse separate q/k/v back into the original qkv weight
            key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

            # Cross-attention: q stays separate, k/v are fused into kv_linear
            key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
            new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

    if len(matched) < len(state_dict):
        print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        print(f"PixArt: UNET conversion has missing keys!")
        print(missing)

    return new_state_dict
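

# Example usage (sketch, not part of the original converter): load a diffusers-format
# PixArt transformer checkpoint and save it back out in the original key layout.
# The safetensors dependency, function name and file paths are illustrative assumptions;
# any loader that yields a flat {key: tensor} dict works the same way.
def example_convert_checkpoint(path_in, path_out):
    from safetensors.torch import load_file, save_file  # assumed available
    diffusers_sd = load_file(path_in)   # e.g. "diffusion_pytorch_model.safetensors" (hypothetical)
    original_sd = convert_state_dict(diffusers_sd)
    save_file(original_sd, path_out)    # hypothetical output path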


# Same as above but for LoRA weights:
def convert_lora_state_dict(state_dict, peft=True):
    # kohya-style target keys (applied to the original PixArt module names)
    rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
    rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
    rep_pk = lambda x: x.replace(".weight", ".alpha")
    if peft:  # PEFT / diffusers-style source keys (lora_A / lora_B)
        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
        rep_pp = lambda x: x.replace(".weight", ".alpha")

        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    else:  # OneTrainer
        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"

        prefix = "lora_transformer_"
        t5_marker = "lora_te_encoder"
        t5_keys = []
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
            elif t5_marker in key:
                state_dict.pop(key)
                t5_keys.append(key)  # track dropped text encoder keys (count only)
        if len(t5_keys) > 0:
            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")

    cmap = []
    cmap_unet = get_conversion_map(state_dict) + conversion_map_ms  # todo: 512 model
    for k, v in cmap_unet:
        if v.endswith(".weight"):
            cmap.append((rep_ak(k), rep_ap(v)))
            cmap.append((rep_bk(k), rep_bp(v)))
            if not peft:
                cmap.append((rep_pk(k), rep_pp(v)))

    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = list(v for k, v in cmap if v in state_dict.keys())

    lora_depth = get_lora_depth(state_dict)
    # Two passes: one for the A/down weights, one for the B/up weights
    # (fp = source key style, fk = target key style)
    for fp, fk in ((rep_ap, rep_ak), (rep_bp, rep_bk)):
        for depth in range(lora_depth):
            # Self Attention
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)

            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Self-attention output projection
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            # Cross-attention (linear)
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
            new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Cross-attention output projection
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key]
            matched += [key]

            key = fp(f"transformer_blocks.{depth}.ff.net.2.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key]
            matched += [key]

    if len(matched) < len(state_dict):
        print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        print(f"PixArt: LoRA conversion has missing keys! (probably)")
        print(missing)

    return new_state_dict
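

# Example usage (sketch, not part of the original converter): convert a PEFT/diffusers-style
# PixArt LoRA file (or a kohya/OneTrainer one with peft=False) into the lora_down/lora_up
# layout on the original PixArt module names produced above. The safetensors dependency,
# function name and file paths are illustrative assumptions.
def example_convert_lora(path_in, path_out, peft=True):
    from safetensors.torch import load_file, save_file  # assumed available
    lora_sd = load_file(path_in)                         # hypothetical input path
    converted = convert_lora_state_dict(lora_sd, peft=peft)
    save_file(converted, path_out)                       # hypothetical output path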