File size: 1,661 Bytes
0b7b08a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
from collections import OrderedDict

import megengine as mge
import torch


def make_parser():
    """Build the command-line parser for the weight conversion script.

    Options:
        -w / --weights: path of the weight file to read (no default, so the
            parsed value is ``None`` when the option is omitted).
        -o / --output: path the converted weights are written to; defaults
            to ``weight_mge.pkl``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-w", "--weights", type=str, help="path of weight file")
    parser.add_argument(
        "-o",
        "--output",
        default="weight_mge.pkl",
        type=str,
        # The help text used to duplicate the one of --weights, which made
        # the two options indistinguishable in --help output.
        help="path of the converted output weight file",
    )
    return parser


def numpy_weights(weight_file):
    """Load a PyTorch checkpoint and return its entries as NumPy arrays.

    The file is loaded onto the CPU. If the loaded mapping has a ``"model"``
    key, the mapping stored under that key is used instead of the top-level
    one.

    Args:
        weight_file: path (or file-like object) accepted by ``torch.load``.

    Returns:
        An ``OrderedDict`` mapping each key to the NumPy version of its
        tensor, in the checkpoint's iteration order.
    """
    # NOTE(security): torch.load unpickles the file; only feed it checkpoints
    # from a trusted source.
    checkpoint = torch.load(weight_file, map_location="cpu")
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]

    # detach() first: Tensor.numpy() raises on tensors that require grad,
    # which happens when parameters were saved without going through
    # state_dict().
    return OrderedDict(
        (name, tensor.detach().cpu().numpy())
        for name, tensor in checkpoint.items()
    )


def map_weights(weight_file, output_file):
    """Convert the weights in ``weight_file`` and save them to ``output_file``.

    Entries are read through ``numpy_weights`` and handled by key:

    * keys containing ``num_batches_tracked`` are dropped;
    * keys ending in ``bias`` are reshaped to ``(1, -1, 1, 1)``;
    * keys containing both ``dconv`` and ``conv.weight`` (depthwise conv
      weights) get an extra axis of size 1 inserted after the first one
      (presumably MegEngine's grouped-convolution layout; confirm);
    * every other entry is copied unchanged.

    The resulting ordered mapping is written with ``mge.save`` and a message
    is printed for each special-cased key and for the final save.
    """
    converted = OrderedDict()
    for name, array in numpy_weights(weight_file).items():
        if "num_batches_tracked" in name:
            print("drop: {}".format(name))
            continue

        if name.endswith("bias"):
            print("bias key: {}".format(name))
            array = array.reshape(1, -1, 1, 1)
        elif "dconv" in name and "conv.weight" in name:
            print("depthwise conv key: {}".format(name))
            # Unpacking also enforces that the weight is 4-dimensional.
            out_ch, in_ch, kernel_h, kernel_w = array.shape
            array = array.reshape(out_ch, 1, in_ch, kernel_h, kernel_w)

        converted[name] = array

    mge.save(converted, output_file)
    print("save weights to {}".format(output_file))


def main():
    """Parse the command-line arguments and run the weight conversion.

    Exits with an argparse usage error (status 2) when no input weight file
    was given; otherwise converts ``--weights`` and writes ``--output``.
    """
    parser = make_parser()
    args = parser.parse_args()
    if args.weights is None:
        # Fail at the CLI boundary instead of handing None to the loader,
        # where the resulting error would not mention the missing option.
        parser.error("the following arguments are required: -w/--weights")
    map_weights(args.weights, args.output)


# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()