chendl's picture
Add application file
0b7b08a
raw
history blame
1.66 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
from collections import OrderedDict
import megengine as mge
import torch
def make_parser():
    """Build the command-line parser for the weight-conversion script.

    Returns:
        argparse.ArgumentParser: parser exposing ``-w/--weights`` (path of
        the checkpoint to convert) and ``-o/--output`` (path the converted
        weights are written to, default ``"weight_mge.pkl"``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-w",
        "--weights",
        type=str,
        help="path of the input torch weight file",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="weight_mge.pkl",
        type=str,
        # This help text used to duplicate the --weights one, which made the
        # generated --help output ambiguous about input vs. output.
        help="path of the converted MegEngine weight file to write",
    )
    return parser
def numpy_weights(weight_file):
    """Load a torch checkpoint and return its tensors as numpy arrays.

    If the loaded object contains a ``"model"`` key, the value stored under
    that key is used as the weight mapping; otherwise the loaded object
    itself is used.

    Args:
        weight_file: checkpoint location, passed straight to ``torch.load``.

    Returns:
        OrderedDict: each key of the weight mapping mapped to the numpy
        array of its tensor, in the mapping's iteration order.
    """
    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(weight_file, map_location="cpu")
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]

    # detach() first: Tensor.numpy() raises RuntimeError for tensors that
    # still require grad (e.g. parameters saved without going through
    # state_dict()).
    return OrderedDict(
        (key, tensor.detach().cpu().numpy())
        for key, tensor in checkpoint.items()
    )
def map_weights(weight_file, output_file):
    """Convert a torch weight file and save the result with ``mge.save``.

    Per-key handling, in the order the keys appear in the input:

    * keys containing ``"num_batches_tracked"`` are dropped;
    * keys ending in ``"bias"`` are reshaped to ``(1, -1, 1, 1)``;
    * keys containing both ``"dconv"`` and ``"conv.weight"`` are treated as
      depthwise-conv weights: the 4-D array ``(cout, cin, k1, k2)`` is
      reshaped to ``(cout, 1, cin, k1, k2)``;
    * every other entry is copied through unchanged.

    Args:
        weight_file: path of the checkpoint, forwarded to ``numpy_weights``.
        output_file: destination forwarded to ``mge.save``.
    """
    converted = OrderedDict()

    for name, array in numpy_weights(weight_file).items():
        if "num_batches_tracked" in name:
            print("drop: {}".format(name))
            continue

        is_depthwise = "dconv" in name and "conv.weight" in name

        # The bias check takes precedence over the depthwise check.
        if name.endswith("bias"):
            print("bias key: {}".format(name))
            array = array.reshape(1, -1, 1, 1)
        elif is_depthwise:
            print("depthwise conv key: {}".format(name))
            out_ch, in_ch, k_h, k_w = array.shape
            # Insert a singleton axis after the first dimension.
            # NOTE(review): presumably MegEngine's grouped-conv weight
            # layout - confirm against the target model.
            array = array.reshape(out_ch, 1, in_ch, k_h, k_w)

        converted[name] = array

    mge.save(converted, output_file)
    print("save weights to {}".format(output_file))
def main():
    """Script entry point: parse the CLI arguments and run the conversion."""
    cli_args = make_parser().parse_args()
    map_weights(cli_args.weights, cli_args.output)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()