|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import struct |
|
import torch |
|
from typing import Dict |
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the converter's command-line arguments from ``sys.argv``.

    Returns:
        A namespace with three string attributes: ``src_path``,
        ``dest_path`` and ``data_type`` (``'float16'`` or ``'float32'``;
        ``'float32'`` when the argument is omitted).
    """
    parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file')
    parser.add_argument('src_path', help='Path to PyTorch checkpoint file')
    parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten')
    # argparse ignores `default` on a required positional argument, so the
    # argument is made optional with nargs='?' to let the declared default
    # actually apply. Passing the value explicitly works exactly as before.
    parser.add_argument(
        'data_type',
        help='Data type, float16 or float32',
        type=str,
        nargs='?',
        choices=['float16', 'float32'],
        default='float32',
    )
    return parser.parse_args()
|
|
|
def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    """Count the consecutive ``blocks.<i>.ln1.weight`` keys in *state_dict*.

    Indices are probed from 0 upwards and counting stops at the first
    missing index, so any blocks located after a gap in the numbering are
    not counted.

    Args:
        state_dict: Mapping from parameter name to tensor. Only the keys
            are inspected.

    Returns:
        The number of layers found; always at least 1.

    Raises:
        AssertionError: If ``blocks.0.ln1.weight`` is not present.
    """
    n_layer = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1

    # Raised explicitly rather than through an `assert` statement so the
    # check is not stripped when Python runs with -O. The exception type is
    # kept as AssertionError for backward compatibility.
    if n_layer == 0:
        raise AssertionError(
            "state_dict has no 'blocks.0.ln1.weight' entry; cannot determine the layer count"
        )

    return n_layer
|
|
|
def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
    """Serialize *state_dict* into a binary file at *dest_path*.

    All integers are written as 32-bit signed values in native byte order
    without padding (``struct`` format prefix ``=``).

    File layout:

    * Header, six integers: the magic ``0x67676d66`` (the ASCII codes of
      ``ggmf``), the constant ``101`` (presumably a format version --
      confirm against the reader), ``n_vocab``, ``n_embed``, ``n_layer``
      and a data type flag (``1`` for ``'float16'``, ``0`` otherwise).
    * Then, for every entry in iteration order: the number of dimensions,
      the byte length of the UTF-8 encoded key, a dtype flag (``1`` for
      float16, ``0`` otherwise), the dimensions in reversed order, the key
      bytes and finally the raw tensor data.

    Args:
        state_dict: Mapping from parameter name to tensor. Must contain
            ``emb.weight``, whose first two dimensions provide ``n_vocab``
            and ``n_embed``.
        dest_path: Output file path; an existing file is overwritten.
        data_type: ``'float16'`` stores tensors with more than one
            dimension as float16; any other value keeps everything float32.

    Raises:
        KeyError: If ``emb.weight`` is missing from *state_dict*.
    """
    emb_weight: torch.Tensor = state_dict['emb.weight']

    n_layer = get_layer_count(state_dict)
    n_vocab = emb_weight.shape[0]
    n_embed = emb_weight.shape[1]
    use_half = data_type == 'float16'

    with open(dest_path, 'wb') as out_file:
        out_file.write(struct.pack(
            '=iiiiii',
            0x67676d66,
            101,
            n_vocab,
            n_embed,
            n_layer,
            1 if use_half else 0,
        ))

        for k, value in state_dict.items():
            # detach() first: Tensor.numpy() refuses tensors that require
            # grad, which happens when a checkpoint stored nn.Parameter
            # objects. It is a no-op for plain tensors.
            tensor = value.detach().float()

            if '.time_' in k:
                # Drop all size-1 dimensions of the time-mixing parameters.
                tensor = tensor.squeeze()

            if '.time_decay' in k:
                # The decay is stored already transformed as -exp(w).
                tensor = -torch.exp(tensor)

            # One-dimensional tensors stay float32 even in float16 mode.
            if use_half and len(tensor.shape) > 1:
                tensor = tensor.half()

            shape = tensor.shape

            print(f'Writing {k}, shape {shape}, type {tensor.dtype}')

            k_encoded: bytes = k.encode('utf-8')

            out_file.write(struct.pack(
                '=iii',
                len(shape),
                len(k_encoded),
                1 if tensor.dtype == torch.float16 else 0,
            ))

            # Dimensions are written in the reverse of PyTorch's order
            # (presumably the order the reader expects -- confirm). A single
            # pack call emits the same bytes as one call per dimension.
            out_file.write(struct.pack(f'={len(shape)}i', *reversed(shape)))

            out_file.write(k_encoded)

            tensor.numpy().tofile(out_file)
|
|
|
def main() -> None:
    """Load the checkpoint named on the command line and convert it.

    Reads the source path, destination path and data type via
    ``parse_args``, loads the checkpoint onto the CPU and hands it to
    ``write_state_dict``.
    """
    cli = parse_args()
    source, destination = cli.src_path, cli.dest_path

    print(f'Reading {source}')

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only run this on checkpoints from a trusted source.
    checkpoint: Dict[str, torch.Tensor] = torch.load(source, map_location='cpu')

    write_state_dict(checkpoint, destination, cli.data_type)

    print('Done')
|
|
|
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()