Commit
·
30757c1
1
Parent(s):
9491ede
feat(tools): add batch-size option for onnx conversion process (#582)
Browse files- tools/export_onnx.py +13 -2
tools/export_onnx.py
CHANGED
@@ -28,6 +28,10 @@ def make_parser():
|
|
28 |
parser.add_argument(
|
29 |
"-o", "--opset", default=11, type=int, help="onnx opset version"
|
30 |
)
|
|
|
|
|
|
|
|
|
31 |
parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
|
32 |
parser.add_argument(
|
33 |
"-f",
|
@@ -77,13 +81,16 @@ def main():
|
|
77 |
model.head.decode_in_inference = False
|
78 |
|
79 |
logger.info("loading checkpoint done.")
|
80 |
-
dummy_input = torch.randn(
|
|
|
81 |
torch.onnx._export(
|
82 |
model,
|
83 |
dummy_input,
|
84 |
args.output_name,
|
85 |
input_names=[args.input],
|
86 |
output_names=[args.output],
|
|
|
|
|
87 |
opset_version=args.opset,
|
88 |
)
|
89 |
logger.info("generated onnx model named {}".format(args.output_name))
|
@@ -93,9 +100,13 @@ def main():
|
|
93 |
|
94 |
from onnxsim import simplify
|
95 |
|
|
|
|
|
96 |
# use onnxsimplify to reduce reduent model.
|
97 |
onnx_model = onnx.load(args.output_name)
|
98 |
-
model_simp, check = simplify(onnx_model
|
|
|
|
|
99 |
assert check, "Simplified ONNX model could not be validated"
|
100 |
onnx.save(model_simp, args.output_name)
|
101 |
logger.info("generated simplified onnx model named {}".format(args.output_name))
|
|
|
28 |
parser.add_argument(
|
29 |
"-o", "--opset", default=11, type=int, help="onnx opset version"
|
30 |
)
|
31 |
+
parser.add_argument("--batch-size", type=int, default=1, help="batch size")
|
32 |
+
parser.add_argument(
|
33 |
+
"--dynamic", action="store_true", help="whether the input shape should be dynamic or not"
|
34 |
+
)
|
35 |
parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
|
36 |
parser.add_argument(
|
37 |
"-f",
|
|
|
81 |
model.head.decode_in_inference = False
|
82 |
|
83 |
logger.info("loading checkpoint done.")
|
84 |
+
dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
|
85 |
+
|
86 |
torch.onnx._export(
|
87 |
model,
|
88 |
dummy_input,
|
89 |
args.output_name,
|
90 |
input_names=[args.input],
|
91 |
output_names=[args.output],
|
92 |
+
dynamic_axes={args.input: {0: 'batch'},
|
93 |
+
args.output: {0: 'batch'}} if args.dynamic else None,
|
94 |
opset_version=args.opset,
|
95 |
)
|
96 |
logger.info("generated onnx model named {}".format(args.output_name))
|
|
|
100 |
|
101 |
from onnxsim import simplify
|
102 |
|
103 |
+
input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
|
104 |
+
|
105 |
# use onnxsimplify to reduce reduent model.
|
106 |
onnx_model = onnx.load(args.output_name)
|
107 |
+
model_simp, check = simplify(onnx_model,
|
108 |
+
dynamic_input_shape=args.dynamic,
|
109 |
+
input_shapes=input_shapes)
|
110 |
assert check, "Simplified ONNX model could not be validated"
|
111 |
onnx.save(model_simp, args.output_name)
|
112 |
logger.info("generated simplified onnx model named {}".format(args.output_name))
|