Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii, Inc. and its affiliates. | |
import argparse | |
import megengine as mge | |
import numpy as np | |
from megengine import jit | |
from build import build_and_load | |
def make_parser():
    """Create the command-line parser for the YOLOX MegEngine dump script.

    Recognized options:
        -n / --name    model name string (default "yolox-s")
        -c / --ckpt    checkpoint path (default None)
        --dump_path    destination file for the dumped graph (default "model.mge")

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    cli = argparse.ArgumentParser(prog="YOLOX Demo Dump")
    cli.add_argument("-n", "--name", default="yolox-s", type=str, help="model name")
    cli.add_argument("-c", "--ckpt", type=str, default=None, help="ckpt for eval")
    cli.add_argument(
        "--dump_path",
        help="path to save the dumped model",
        default="model.mge",
    )
    return cli
def dump_static_graph(
    model, graph_name="model.mge", input_shape=(1, 3, 640, 640)
):
    """Trace ``model`` on a dummy input and dump it as a static MegEngine graph.

    The model is switched to eval mode and ``model.head.decode_in_inference``
    is set to False before tracing. NOTE(review): presumably this makes the
    dumped graph emit raw head outputs and leaves box decoding to the
    deployment side. Confirm against the YOLOX head implementation.

    Args:
        model: MegEngine module exposing ``eval()`` and a ``head`` attribute.
        graph_name: file path the serialized graph is written to.
        input_shape: shape of the random dummy input used for tracing.
            Presumably (N, C, H, W). The default is one 3-channel 640x640
            image, as in the original script. The graph is traced at this
            fixed shape.
    """
    model.eval()
    model.head.decode_in_inference = False

    # np.random.random yields float64; cast explicitly so the traced graph's
    # input dtype is float32 rather than relying on implicit conversion.
    data = mge.Tensor(np.random.random(input_shape).astype(np.float32))

    # jit.trace wraps pred_func in a traced-function object. Without it,
    # ``pred_func.dump`` does not exist and the dump call below raises
    # AttributeError. capture_as_const records the model's parameters as
    # constants in the graph, which is required for dump().
    @jit.trace(capture_as_const=True)
    def pred_func(data):
        return model(data)

    # Run once so the trace is recorded before it is serialized.
    pred_func(data)
    pred_func.dump(
        graph_name,
        arg_names=["data"],
        optimize_for_inference=True,
        enable_fuse_conv_bias_nonlinearity=True,
    )
def main(args):
    """Build the requested YOLOX model and write its static graph to disk.

    Args:
        args: namespace produced by ``make_parser().parse_args()``. Its
            ``ckpt`` and ``name`` fields are forwarded to ``build_and_load``,
            and ``dump_path`` selects the output file.
    """
    checkpoint, arch = args.ckpt, args.name
    net = build_and_load(checkpoint, name=arch)
    dump_static_graph(net, graph_name=args.dump_path)
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the dump.
    parser = make_parser()
    args = parser.parse_args()
    main(args)