|
--- |
|
license: apache-2.0 |
|
language: |
|
- ko |
|
tags: |
|
- rwkv |
|
- KoRWKV |
|
--- |
|
|
|
# KoRWKV |
|
|
|
[RWKV-Runner](https://github.com/josStorer/RWKV-Runner)에서 사용하기 위해 변환한 모델 파일
|
|
|
- [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B) |
|
- [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B) |
|
|
|
```py |
|
import re |
|
|
|
import torch |
|
|
|
from transformers import RwkvForCausalLM |
|
|
|
def convert_state_dict(state_dict):
    """Rename original RWKV checkpoint keys to Hugging Face ``RwkvForCausalLM`` keys.

    Renames applied to every key:
      emb.*                    -> embeddings.*
      blocks.0.ln0.*           -> blocks.0.pre_ln.*   (ln0 only exists in block 0)
      blocks.N.att.*           -> blocks.N.attention.*
      blocks.N.ffn.*           -> blocks.N.feed_forward.*
      *.time_mix_k / _v / _r   -> *.time_mix_key / _value / _receptance
    and every key except ``head.weight`` is then prefixed with ``rwkv.``.
    Only names change; the tensor values are stored back untouched.

    Args:
        state_dict: dict mapping original RWKV parameter names to tensors.
            It is modified in place.

    Returns:
        The same dict object with renamed keys, in the original key order.
    """
    time_mix_suffixes = (
        (".time_mix_k", ".time_mix_key"),
        (".time_mix_v", ".time_mix_value"),
        (".time_mix_r", ".time_mix_receptance"),
    )
    # Snapshot the keys first because the dict is mutated inside the loop.
    for old_name in list(state_dict):
        # Pop and re-insert so the final key order follows the input order.
        weight = state_dict.pop(old_name)
        name = old_name

        # emb.* -> embeddings.*
        if name.startswith("emb."):
            name = "embeddings." + name.removeprefix("emb.")
        # blocks.0.ln0.* -> blocks.0.pre_ln.* (ln0 is only present at block 0)
        if name.startswith("blocks.0.ln0."):
            name = "blocks.0.pre_ln." + name.removeprefix("blocks.0.ln0.")
        # blocks.N.att.* -> blocks.N.attention.*; matching the trailing "." keeps
        # an already-renamed "attention" segment from turning into "attentionention".
        name = re.sub(r"blocks\.(\d+)\.att\.", r"blocks.\1.attention.", name)
        # blocks.N.ffn.* -> blocks.N.feed_forward.*
        name = re.sub(r"blocks\.(\d+)\.ffn\.", r"blocks.\1.feed_forward.", name)
        # time_mix_{k,v,r} -> time_mix_{key,value,receptance} (rename only, no reshape).
        for short, full in time_mix_suffixes:
            if name.endswith(short):
                name = name.removesuffix(short) + full
                break

        # Every key except head.weight gets the "rwkv." prefix.
        if name != "head.weight":
            name = "rwkv." + name

        state_dict[name] = weight
    return state_dict
|
|
|
|
|
def revert_state_dict(state_dict):
    """Rename Hugging Face ``RwkvForCausalLM`` keys back to original RWKV checkpoint keys.

    This is the reverse of the renaming done by ``convert_state_dict``:
      rwkv.*                            -> *   (prefix stripped first)
      embeddings.*                      -> emb.*
      blocks.0.pre_ln.*                 -> blocks.0.ln0.*
      blocks.N.attention.*              -> blocks.N.att.*
      blocks.N.feed_forward.*           -> blocks.N.ffn.*
      *.time_mix_key/_value/_receptance -> *.time_mix_k/_v/_r
    ``head.weight`` has no ``rwkv.`` prefix and is kept as-is.
    Only names change; the tensor values are stored back untouched.

    Args:
        state_dict: dict mapping Hugging Face RWKV parameter names to tensors,
            e.g. the result of ``model.state_dict()``. It is modified in place.

    Returns:
        The same dict object with renamed keys, in the original key order.
    """
    time_mix_suffixes = (
        (".time_mix_key", ".time_mix_k"),
        (".time_mix_value", ".time_mix_v"),
        (".time_mix_receptance", ".time_mix_r"),
    )
    # Snapshot the keys first because the dict is mutated inside the loop.
    for old_name in list(state_dict):
        # Pop and re-insert so the final key order follows the input order.
        weight = state_dict.pop(old_name)
        name = old_name.removeprefix("rwkv.")

        # embeddings.* -> emb.*
        if name.startswith("embeddings."):
            name = "emb." + name.removeprefix("embeddings.")
        # blocks.0.pre_ln.* -> blocks.0.ln0.* (pre_ln is only present at block 0)
        if name.startswith("blocks.0.pre_ln."):
            name = "blocks.0.ln0." + name.removeprefix("blocks.0.pre_ln.")
        # blocks.N.attention.* -> blocks.N.att.*
        name = re.sub(r"blocks\.(\d+)\.attention\.", r"blocks.\1.att.", name)
        # blocks.N.feed_forward.* -> blocks.N.ffn.*
        name = re.sub(r"blocks\.(\d+)\.feed_forward\.", r"blocks.\1.ffn.", name)
        # time_mix_{key,value,receptance} -> time_mix_{k,v,r} (rename only, no reshape).
        for full, short in time_mix_suffixes:
            if name.endswith(full):
                name = name.removesuffix(full) + short
                break

        state_dict[name] = weight
    return state_dict
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: download a Hugging Face RWKV model and save its
    # weights under the original RWKV key names for RWKV-Runner.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face RWKV model into a .pth file for RWKV-Runner."
    )
    parser.add_argument(
        "repo",
        nargs="?",
        default="beomi/KoAlpaca-KoRWKV-6B",
        help="Hugging Face repo id (e.g. beomi/KoAlpaca-KoRWKV-6B or beomi/KoRWKV-6B)",
    )
    repo = parser.parse_args().repo

    model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)

    # revert_state_dict renames the keys of this dict in place and returns it.
    state_dict = model.state_dict()
    converted = revert_state_dict(state_dict)

    # Output file is named after the repo, e.g. "KoAlpaca-KoRWKV-6B.bf16.pth".
    name = repo.split("/")[-1] + ".bf16.pth"
    torch.save(converted, name)
|
``` |
|
|