developer0hye commited on
Commit
30757c1
·
1 Parent(s): 9491ede

feat(tools): add batch-size option for onnx conversion process (#582)

Browse files
Files changed (1) hide show
  1. tools/export_onnx.py +13 -2
tools/export_onnx.py CHANGED
@@ -28,6 +28,10 @@ def make_parser():
28
  parser.add_argument(
29
  "-o", "--opset", default=11, type=int, help="onnx opset version"
30
  )
 
 
 
 
31
  parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
32
  parser.add_argument(
33
  "-f",
@@ -77,13 +81,16 @@ def main():
77
  model.head.decode_in_inference = False
78
 
79
  logger.info("loading checkpoint done.")
80
- dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
 
81
  torch.onnx._export(
82
  model,
83
  dummy_input,
84
  args.output_name,
85
  input_names=[args.input],
86
  output_names=[args.output],
 
 
87
  opset_version=args.opset,
88
  )
89
  logger.info("generated onnx model named {}".format(args.output_name))
@@ -93,9 +100,13 @@ def main():
93
 
94
  from onnxsim import simplify
95
 
 
 
96
  # use onnxsimplify to reduce reduent model.
97
  onnx_model = onnx.load(args.output_name)
98
- model_simp, check = simplify(onnx_model)
 
 
99
  assert check, "Simplified ONNX model could not be validated"
100
  onnx.save(model_simp, args.output_name)
101
  logger.info("generated simplified onnx model named {}".format(args.output_name))
 
28
  parser.add_argument(
29
  "-o", "--opset", default=11, type=int, help="onnx opset version"
30
  )
31
+ parser.add_argument("--batch-size", type=int, default=1, help="batch size")
32
+ parser.add_argument(
33
+ "--dynamic", action="store_true", help="whether the input shape should be dynamic or not"
34
+ )
35
  parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
36
  parser.add_argument(
37
  "-f",
 
81
  model.head.decode_in_inference = False
82
 
83
  logger.info("loading checkpoint done.")
84
+ dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
85
+
86
  torch.onnx._export(
87
  model,
88
  dummy_input,
89
  args.output_name,
90
  input_names=[args.input],
91
  output_names=[args.output],
92
+ dynamic_axes={args.input: {0: 'batch'},
93
+ args.output: {0: 'batch'}} if args.dynamic else None,
94
  opset_version=args.opset,
95
  )
96
  logger.info("generated onnx model named {}".format(args.output_name))
 
100
 
101
  from onnxsim import simplify
102
 
103
+ input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
104
+
105
  # use onnxsimplify to reduce reduent model.
106
  onnx_model = onnx.load(args.output_name)
107
+ model_simp, check = simplify(onnx_model,
108
+ dynamic_input_shape=args.dynamic,
109
+ input_shapes=input_shapes)
110
  assert check, "Simplified ONNX model could not be validated"
111
  onnx.save(model_simp, args.output_name)
112
  logger.info("generated simplified onnx model named {}".format(args.output_name))