File size: 3,220 Bytes
7134ebe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
---
license: apache-2.0
language:
- ko
tags:
- rwkv
- KoRWKV
---
# KoRWKV
[RWKV-Runner](https://github.com/josStorer/RWKV-Runner)에서 사용하기 위해 변환한 모델 파일
- [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B)
- [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B)
```py
import re
import torch
from transformers import RwkvForCausalLM
def _hf_key_from_rwkv(name):
    """Return the Hugging Face ``RwkvForCausalLM`` name for one original RWKV key."""
    # emb.* -> embeddings.*
    if name.startswith("emb."):
        name = "embeddings." + name.removeprefix("emb.")
    # ln0 -> pre_ln (only present at block 0)
    if name.startswith("blocks.0.ln0"):
        name = "blocks.0.pre_ln" + name.removeprefix("blocks.0.ln0")
    # att -> attention
    name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
    # ffn -> feed_forward
    name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
    # time_mix_{k,v,r} -> time_mix_{key,value,receptance} (rename only, no reshape)
    for old_suffix, new_suffix in (
        (".time_mix_k", ".time_mix_key"),
        (".time_mix_v", ".time_mix_value"),
        (".time_mix_r", ".time_mix_receptance"),
    ):
        if name.endswith(old_suffix):
            name = name[: -len(old_suffix)] + new_suffix
            break
    # Everything except the LM head lives under the "rwkv." sub-module in HF.
    if name != "head.weight":
        name = "rwkv." + name
    return name


def convert_state_dict(state_dict):
    """Rename an original RWKV checkpoint's keys to the HF ``RwkvForCausalLM`` layout.

    The renaming is applied to ``state_dict`` in place and the same object is
    returned; key order is preserved. Values (tensors) are passed through
    untouched -- only their keys change.

    Raises:
        ValueError: if two input keys would map to the same output key. In that
            case ``state_dict`` is left unmodified.
    """
    renamed = {}
    for name, weight in state_dict.items():
        new_name = _hf_key_from_rwkv(name)
        if new_name in renamed:
            raise ValueError(f"keys collide after renaming: {new_name!r}")
        renamed[new_name] = weight
    # Rebuild only after every target name is known, so a renamed key can never
    # overwrite an input key that has not been processed yet.
    state_dict.clear()
    state_dict.update(renamed)
    return state_dict
def _rwkv_key_from_hf(name):
    """Return the original RWKV checkpoint name for one HF ``RwkvForCausalLM`` key."""
    name = name.removeprefix("rwkv.")
    # embeddings.* -> emb.*
    if name.startswith("embeddings."):
        name = "emb." + name.removeprefix("embeddings.")
    # pre_ln -> ln0 (only present at block 0)
    if name.startswith("blocks.0.pre_ln"):
        name = "blocks.0.ln0" + name.removeprefix("blocks.0.pre_ln")
    # attention -> att
    name = re.sub(r"blocks\.(\d+)\.attention", r"blocks.\1.att", name)
    # feed_forward -> ffn
    name = re.sub(r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn", name)
    # time_mix_{key,value,receptance} -> time_mix_{k,v,r} (rename only, no reshape)
    for old_suffix, new_suffix in (
        (".time_mix_key", ".time_mix_k"),
        (".time_mix_value", ".time_mix_v"),
        (".time_mix_receptance", ".time_mix_r"),
    ):
        if name.endswith(old_suffix):
            name = name[: -len(old_suffix)] + new_suffix
            break
    return name


def revert_state_dict(state_dict):
    """Rename HF ``RwkvForCausalLM`` state-dict keys back to the original RWKV layout.

    This is the inverse of the HF key mapping. The "rwkv." prefix is stripped,
    and embeddings / pre_ln / attention / feed_forward / time_mix_* are renamed
    back. The renaming is applied to ``state_dict`` in place and the same object
    is returned; key order is preserved and values are passed through untouched.

    Raises:
        ValueError: if two input keys would map to the same output key. In that
            case ``state_dict`` is left unmodified.
    """
    renamed = {}
    for name, weight in state_dict.items():
        new_name = _rwkv_key_from_hf(name)
        if new_name in renamed:
            raise ValueError(f"keys collide after renaming: {new_name!r}")
        renamed[new_name] = weight
    # Rebuild only after every target name is known, so a renamed key can never
    # overwrite an input key that has not been processed yet.
    state_dict.clear()
    state_dict.update(renamed)
    return state_dict
if __name__ == "__main__":
    import sys

    # Hub repos to convert may be given on the command line, e.g.
    #   python convert.py beomi/KoRWKV-6B beomi/KoAlpaca-KoRWKV-6B
    # With no arguments the KoAlpaca checkpoint is converted.
    repos = sys.argv[1:] or ["beomi/KoAlpaca-KoRWKV-6B"]
    for repo in repos:
        model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
        # Rename HF parameter names back to the original RWKV checkpoint layout.
        converted = revert_state_dict(model.state_dict())
        # e.g. "beomi/KoAlpaca-KoRWKV-6B" -> "KoAlpaca-KoRWKV-6B.bf16.pth"
        name = repo.split("/")[-1] + ".bf16.pth"
        torch.save(converted, name)
        # Drop references before loading the next (multi-GB) model.
        del model, converted
```
|