File size: 3,031 Bytes
fa6856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# ref: https://gist.github.com/benob/4850a0210b01672175942203aa36d300
import os
import json
import sys
import torch
import glob

# python test.py 2 xx/checkpoint-1000/ckpt/ outs

if len(sys.argv) != 4:
    print('usage: %s <new-shards> <input-model-path> <output-model-path>' % sys.argv[0], file=sys.stderr)
    sys.exit(1)

num_shards = int(sys.argv[1])
input_model_dir = sys.argv[2]
output_model_dir = sys.argv[3]

with open(os.path.join(input_model_dir, 'params.json'), 'r') as fp:
    params = json.loads(fp.read())

assert params['dim'] % num_shards == 0, "number of shards need to divide parameter dimension %d" % params['dim']

print('loading...')
checkpoints = [torch.load(path, map_location=torch.device('cpu')) for path in glob.glob(os.path.join(input_model_dir, '*.pth'))]

layer_kind = {
    'tok_embeddings': 'ParallelEmbedding',
    'output': 'ColumnParallelLinear',
    'attention.wq': 'ColumnParallelLinear',
    'attention.wk': 'ColumnParallelLinear',
    'attention.wv': 'ColumnParallelLinear',
    'attention.wo': 'RowParallelLinear',
    'feed_forward.w1': 'ColumnParallelLinear',
    'feed_forward.w2': 'RowParallelLinear',
    'feed_forward.w3': 'ColumnParallelLinear',
    'attention_norm': None,
    'ffn_norm': None,
    'norm': None,
    'rope.freqs': None,
}

output = [dict() for x in range(num_shards)]

print('converting...')
for key in checkpoints[0].keys():
    tensors = [m[key] for m in checkpoints]
    print(key)
    print('  in shapes=', [p.shape for p in tensors])
    for pattern, kind in layer_kind.items():
        if key.replace('.weight', '').endswith(pattern):
            print('  kind=', kind)
            if kind == 'ColumnParallelLinear':
                with torch.no_grad():
                    merged = torch.cat(tensors, 0)
                    slice_size = merged.shape[0] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[slice_size * rank: slice_size * (rank + 1),:].clone().detach()
            elif kind in ('ParallelEmbedding', 'RowParallelLinear'):
                with torch.no_grad():
                    merged = torch.cat(tensors, 1)
                    slice_size = merged.shape[1] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[:,slice_size * rank: slice_size * (rank + 1)].clone().detach()
            else:
                for rank in range(num_shards):
                    output[rank][key] = tensors[0]
            print('  out shapes=', [output[rank][key].shape for rank in range(num_shards)])
            print()
            break
    else:
        raise Exception('parameter name not recognized')

print('saving...')
os.makedirs(output_model_dir, exist_ok=True)
with open(os.path.join(output_model_dir, 'params.json'), 'w') as fp:
    fp.write(json.dumps(params))

for rank in range(num_shards):
    print(' ', rank)
    torch.save(output[rank], os.path.join(output_model_dir, 'consolidated.%02d.pth' % rank))

print('done.')