#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import argparse
from collections import OrderedDict

import megengine as mge
import torch
def make_parser():
    """Build the command-line parser for the weight-conversion script.

    Returns:
        argparse.ArgumentParser: parser exposing ``-w/--weights`` (path of
        the checkpoint to convert, required) and ``-o/--output`` (path of
        the file to write, defaults to ``"weight_mge.pkl"``).
    """
    parser = argparse.ArgumentParser()
    # Without an input checkpoint there is nothing to convert, so fail at
    # argument parsing with a clear usage message instead of passing None
    # further down to the loader.
    parser.add_argument(
        "-w",
        "--weights",
        type=str,
        required=True,
        help="path of the input PyTorch weight file",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="weight_mge.pkl",
        type=str,
        help="path of the converted MegEngine weight file to write",
    )
    return parser
def numpy_weights(weight_file):
    """Load a PyTorch checkpoint and return its tensors as NumPy arrays.

    Args:
        weight_file: checkpoint location, passed straight to ``torch.load``.

    Returns:
        OrderedDict mapping every key of the loaded state dict to the
        corresponding tensor converted to a NumPy array, in the state
        dict's iteration order.
    """
    # NOTE(security): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(weight_file, map_location="cpu")
    # If the checkpoint has a "model" entry, the tensors are taken from it;
    # otherwise the checkpoint itself is treated as the state dict.
    state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    # detach() first: .numpy() raises on tensors that still require grad.
    return OrderedDict(
        (name, tensor.detach().cpu().numpy())
        for name, tensor in state_dict.items()
    )
def map_weights(weight_file, output_file):
    """Convert a PyTorch checkpoint and write the result with ``mge.save``.

    Each entry returned by ``numpy_weights(weight_file)`` is handled by key:

    * keys containing ``"num_batches_tracked"`` are dropped;
    * keys ending in ``"bias"`` are reshaped to ``(1, -1, 1, 1)``;
    * keys containing both ``"dconv"`` and ``"conv.weight"`` (logged as
      depthwise conv weights) are reshaped from 4-D to 5-D by inserting a
      singleton second axis;
    * everything else is copied unchanged.

    Args:
        weight_file: path of the PyTorch checkpoint to read.
        output_file: destination path handed to ``mge.save``.
    """
    converted = OrderedDict()
    for name, array in numpy_weights(weight_file).items():
        if "num_batches_tracked" in name:
            print("drop: {}".format(name))
            continue

        if name.endswith("bias"):
            print("bias key: {}".format(name))
            # NOTE(review): presumably MegEngine stores these biases as
            # (1, C, 1, 1) -- confirm against the target model definition.
            array = array.reshape(1, -1, 1, 1)
        elif "dconv" in name and "conv.weight" in name:
            print("depthwise conv key: {}".format(name))
            # Requires a 4-D weight; the unpack raises ValueError otherwise.
            # NOTE(review): the 5-D shape presumably matches MegEngine's
            # grouped-convolution weight layout -- confirm.
            dim0, dim1, kernel_h, kernel_w = array.shape
            array = array.reshape(dim0, 1, dim1, kernel_h, kernel_w)

        converted[name] = array

    mge.save(converted, output_file)
    print("save weights to {}".format(output_file))
def main():
    """Parse the command-line arguments and run the weight conversion."""
    args = make_parser().parse_args()
    map_weights(args.weights, args.output)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()