|
|
|
|
|
|
|
import os

import torch
|
# Load the source checkpoint onto the CPU so the conversion does not depend on
# the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only run this script on
# checkpoints from a trusted source.
full_state_dict = torch.load("./pytorch_model.bin", map_location="cpu")

# Drop the first dot-separated component of every parameter name; a name
# without any dot becomes the empty string.
full_state_dict = {
    name.partition(".")[2]: tensor for name, tensor in full_state_dict.items()
}
|
|
|
def con_cat(kqv_dict):
    """Fuse separate q/k/v projection tensors into a single ``in_proj`` entry.

    The first key of ``kqv_dict`` decides everything: whether the group holds
    weights or biases, and which name prefix the fused entry gets.

    Args:
        kqv_dict: Mapping from dot-separated parameter name to tensor. The
            fourth name component is the projection name, e.g.
            ``layers.0.self_attn.q_proj.weight``. The mapping must contain
            the ``q_proj``, ``k_proj`` and ``v_proj`` variants of the first
            key.

    Returns:
        A one-item dict mapping
        ``encoder.<first three components>.in_proj_weight`` (or
        ``...in_proj_bias``) to the q, k and v tensors concatenated in that
        order along dimension 0. An empty ``kqv_dict`` yields an empty dict.

    Raises:
        ValueError: If the first key has fewer than four components, or
            contains neither ``weight`` nor ``bias``.
        KeyError: If the q, k or v entry is missing from ``kqv_dict``.
    """
    if not kqv_dict:
        # Nothing to fuse; callers can still pass the result to dict.update().
        return {}

    first_key = next(iter(kqv_dict))
    parts = first_key.split(".")
    if len(parts) < 4:
        raise ValueError(
            f"expected at least four dot-separated components in {first_key!r}"
        )

    if "weight" in first_key:
        kind = "weight"
    elif "bias" in first_key:
        kind = "bias"
    else:
        raise ValueError(f"{first_key!r} is neither a weight nor a bias")

    prefix, suffix = parts[:3], parts[4:]
    # Swap only the fourth component; a plain str.replace would also rewrite
    # any other occurrence of that text in the name.
    fused_value = torch.cat(
        [
            kqv_dict[".".join(prefix + [proj] + suffix)]
            for proj in ("q_proj", "k_proj", "v_proj")
        ]
    )
    fused_key = ".".join(prefix + [f"in_proj_{kind}"])
    return {f"encoder.{fused_key}": fused_value}
|
|
|
|
|
# Start the converted state dict with every parameter whose name contains
# "embedding" or "layer_norm", re-keyed under the "embeddings." prefix.
mod_dict = {
    f"embeddings.{name}": tensor
    for name, tensor in full_state_dict.items()
    if "embedding" in name or "layer_norm" in name
}
|
|
|
|
|
# Discover the layer indices from the checkpoint itself instead of assuming a
# fixed depth, so checkpoints with any number of layers convert.
layer_ids = sorted(
    {
        int(parts[1])
        for parts in (name.split(".") for name in full_state_dict)
        if len(parts) > 1 and parts[0] == "layers" and parts[1].isdigit()
    }
)

for i in layer_ids:
    # The trailing dot keeps "layers.1." from also matching "layers.10." etc.
    layer_prefix = f"layers.{i}."
    kvq_weight = {}
    kvq_bias = {}
    for k, v in full_state_dict.items():
        if not k.startswith(layer_prefix):
            continue
        if "self_attn" in k and "out_proj" not in k:
            # Attention projections other than out_proj are collected here
            # and fused by con_cat below.
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            # Every other parameter of the layer is copied under "encoder.".
            mod_dict[f"encoder.{k}"] = v

    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))
|
|
|
|
|
# Re-key every parameter whose name contains "dense" under the "pooler."
# prefix.
for name, tensor in full_state_dict.items():
    if "dense" not in name:
        continue
    mod_dict[f"pooler.{name}"] = tensor
|
|
|
|
|
# Print each converted parameter name with its size for a manual check.
for name, tensor in mod_dict.items():
    print(name, tensor.size())
|
|
|
model_name = "ernie-m-base_pytorch"
PATH = f"./{model_name}/pytorch_model.bin"

# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the converted state dict.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(mod_dict, PATH)