"""Key-mapping helpers for PixArt transformer checkpoints.

Maps state dicts saved in the Diffusers naming scheme ("pos_embed",
"adaln_single", "transformer_blocks.N", ...) onto the original PixArt
layout ("x_embedder", "t_block", "blocks.N", ...), for both full model
weights and LoRA weights.
"""

import torch


# Extra entries for checkpoints that carry the micro-conditioning embedders
# for image size and aspect ratio (the multi-scale / "MS" variants).
conversion_map_ms = [
    # size (resolution) embedder
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    # aspect ratio embedder
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
]
|
|
def get_depth(state_dict):
    """Count the transformer blocks in a full checkpoint via per-block self-attention keys."""
    return sum(key.endswith('.attn1.to_k.bias') for key in state_dict.keys())
|
|
def get_lora_depth(state_dict):
    """Count the transformer blocks covered by a PEFT-style LoRA state dict."""
    return sum(key.endswith('.attn1.to_k.lora_A.weight') for key in state_dict.keys())
|
|
def get_conversion_map(state_dict):
    """Build the (pixart_key, diffusers_key) map for the keys that rename 1:1."""
    conversion_map = [
        # patch embedding
        ("x_embedder.proj.weight", "pos_embed.proj.weight"),
        ("x_embedder.proj.bias", "pos_embed.proj.bias"),

        # caption projection
        ("y_embedder.y_embedding", "caption_projection.y_embedding"),
        ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
        ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
        ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
        ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),

        # timestep embedding
        ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),

        # adaLN modulation
        ("t_block.1.weight", "adaln_single.linear.weight"),
        ("t_block.1.bias", "adaln_single.linear.bias"),

        # final layer
        ("final_layer.linear.weight", "proj_out.weight"),
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.scale_shift_table", "scale_shift_table"),
    ]

    # Per-block entries; the q/k/v projections are handled separately because
    # they have to be fused (see convert_state_dict).
    for depth in range(get_depth(state_dict)):
        conversion_map += [
            (f"blocks.{depth}.scale_shift_table", f"transformer_blocks.{depth}.scale_shift_table"),

            # self-attention output projection
            (f"blocks.{depth}.attn.proj.weight", f"transformer_blocks.{depth}.attn1.to_out.0.weight"),
            (f"blocks.{depth}.attn.proj.bias", f"transformer_blocks.{depth}.attn1.to_out.0.bias"),

            # feed-forward
            (f"blocks.{depth}.mlp.fc1.weight", f"transformer_blocks.{depth}.ff.net.0.proj.weight"),
            (f"blocks.{depth}.mlp.fc1.bias", f"transformer_blocks.{depth}.ff.net.0.proj.bias"),
            (f"blocks.{depth}.mlp.fc2.weight", f"transformer_blocks.{depth}.ff.net.2.weight"),
            (f"blocks.{depth}.mlp.fc2.bias", f"transformer_blocks.{depth}.ff.net.2.bias"),

            # cross-attention output projection
            (f"blocks.{depth}.cross_attn.proj.weight", f"transformer_blocks.{depth}.attn2.to_out.0.weight"),
            (f"blocks.{depth}.cross_attn.proj.bias", f"transformer_blocks.{depth}.attn2.to_out.0.bias"),
        ]
    return conversion_map
|
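# Hedged sketch (illustration only; this helper is not used elsewhere in the
# module): each map entry is a (pixart_key, diffusers_key) pair, so the reverse
# direction (original PixArt -> Diffusers) can be read off by swapping pairs.
def _reversed_conversion_map(state_dict):
    # Swap (pixart_key, diffusers_key) -> (diffusers_key, pixart_key).
    return [(v, k) for k, v in get_conversion_map(state_dict)]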
|
def find_prefix(state_dict, target_key):
    """Return whatever prefix precedes `target_key` in the first matching key, or ""."""
    prefix = ""
    for k in state_dict.keys():
        if k.endswith(target_key):
            prefix = k.split(target_key)[0]
            break
    return prefix
|
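# Hedged usage sketch: find_prefix is how a wrapper prefix gets stripped before
# conversion. The "transformer." prefix below is a made-up example, not a key
# layout this module guarantees.
def _example_find_prefix():
    sd = {"transformer.adaln_single.linear.lora_A.weight": torch.zeros(1)}
    # -> "transformer."
    return find_prefix(sd, "adaln_single.linear.lora_A.weight")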
|
def convert_state_dict(state_dict):
    """Convert a Diffusers-format PixArt transformer state dict to the original PixArt layout."""
    # Checkpoints with the resolution/aspect-ratio embedders also need the MS map.
    if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict:
        cmap = get_conversion_map(state_dict) + conversion_map_ms
    else:
        cmap = get_conversion_map(state_dict)

    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = [v for k, v in cmap if v in state_dict]

    for depth in range(get_depth(state_dict)):
        for wb in ["weight", "bias"]:
            # Self-attention: fuse the separate q/k/v projections into one qkv tensor.
            key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

            # Cross-attention: q stays on its own, k/v are fused into kv_linear.
            key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
            new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

    if len(matched) < len(state_dict):
        print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        print("PixArt: UNET conversion has missing keys!")
        print(missing)

    return new_state_dict
|
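# Hedged usage sketch: converting a full Diffusers-format checkpoint. The file
# name is a placeholder and the safetensors import is an assumption about the
# caller's environment; this module itself only needs torch.
def _example_convert_checkpoint(path="pixart_transformer.safetensors"):
    from safetensors.torch import load_file  # assumes safetensors is installed
    diffusers_sd = load_file(path)
    return convert_state_dict(diffusers_sd)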
|
def convert_lora_state_dict(state_dict, peft=True):
    """Convert PixArt LoRA weights (PEFT or kohya naming) to the original PixArt layout."""
    # Output (PixArt-side) key patterns.
    rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
    rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
    rep_pk = lambda x: x.replace(".weight", ".alpha")
    if peft:
        # Input (PEFT/Diffusers-side) key patterns.
        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
        rep_pp = lambda x: x.replace(".weight", ".alpha")

        # Strip any wrapper prefix in front of the transformer keys.
        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    else:
        # kohya-style input keys: dots replaced by underscores, ".weight" dropped.
        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"

        prefix = "lora_transformer_"
        t5_marker = "lora_te_encoder"
        t5_keys = []
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
            elif t5_marker in key:
                t5_keys.append(state_dict.pop(key))
        if len(t5_keys) > 0:
            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")
|
    # Remap the LoRA weights for every layer that renames 1:1; the q/k/v
    # projections need fusing and are handled in the loop below.
    cmap = []
    cmap_unet = get_conversion_map(state_dict) + conversion_map_ms
    for k, v in cmap_unet:
        if v.endswith(".weight"):
            cmap.append((rep_ak(k), rep_ap(v)))
            cmap.append((rep_bk(k), rep_bp(v)))
            if not peft:
                cmap.append((rep_pk(k), rep_pp(v)))

    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = [v for k, v in cmap if v in state_dict]
|
    # Handle the per-block attention and feed-forward LoRA weights, once for
    # the down (A) matrices and once for the up (B) matrices.
    lora_depth = get_lora_depth(state_dict)
    for fp, fk in ((rep_ap, rep_ak), (rep_bp, rep_bk)):
        for depth in range(lora_depth):
            # Self-attention: concatenate the q/k/v LoRA matrices along dim 0
            # to match the fused qkv projection.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
                new_state_dict[rep_pk(f"blocks.{depth}.attn.qkv.weight")] = state_dict[akey("q")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Self-attention output projection.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            # Cross-attention: q stays separate, k/v are fused into kv_linear.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
            new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
                new_state_dict[rep_pk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[akey("q")]
                new_state_dict[rep_pk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = state_dict[akey("k")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Cross-attention output projection.
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            # Feed-forward.
            key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key]
            matched += [key]

            key = fp(f"transformer_blocks.{depth}.ff.net.2.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key]
            matched += [key]
|
    if len(matched) < len(state_dict):
        print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        print("PixArt: LoRA conversion has missing keys! (probably)")
        print(missing)

    return new_state_dict
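

# Hedged usage sketch: converting a PEFT-trained PixArt LoRA. The file name is
# a placeholder and the safetensors import is an assumption about the caller's
# environment; pass peft=False instead for kohya-style "lora_transformer_*" keys.
def _example_convert_lora(path="pixart_lora.safetensors"):
    from safetensors.torch import load_file  # assumes safetensors is installed
    lora_sd = load_file(path)
    return convert_lora_state_dict(lora_sd, peft=True)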