Tim Stokman Tim glenn-jocher commited on
Commit
1df8c6c
·
unverified ·
1 Parent(s): 238583b

Fix ONNX dynamic axes export support with onnx simplifier, make onnx simplifier optional (#2856)

Browse files

* Ensure dynamic export works succesfully, onnx simplifier optional

* Update export.py

* add dashes

Co-authored-by: Tim <[email protected]>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. models/export.py +19 -15
models/export.py CHANGED
@@ -21,12 +21,13 @@ from utils.torch_utils import select_device
21
 
22
  if __name__ == '__main__':
23
  parser = argparse.ArgumentParser()
24
- parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path') # from yolov5/models/
25
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
26
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
27
- parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
28
  parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
29
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
 
 
30
  opt = parser.parse_args()
31
  opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
32
  print(opt)
@@ -58,7 +59,7 @@ if __name__ == '__main__':
58
  model.model[-1].export = not opt.grid # set Detect() layer grid export
59
  y = model(img) # dry run
60
 
61
- # TorchScript export
62
  prefix = colorstr('TorchScript:')
63
  try:
64
  print(f'\n{prefix} starting export with torch {torch.__version__}...')
@@ -69,7 +70,7 @@ if __name__ == '__main__':
69
  except Exception as e:
70
  print(f'{prefix} export failure: {e}')
71
 
72
- # ONNX export
73
  prefix = colorstr('ONNX:')
74
  try:
75
  import onnx
@@ -87,21 +88,24 @@ if __name__ == '__main__':
87
  # print(onnx.helper.printable_graph(model_onnx.graph)) # print
88
 
89
  # Simplify
90
- try:
91
- check_requirements(['onnx-simplifier'])
92
- import onnxsim
93
-
94
- print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
95
- model_onnx, check = onnxsim.simplify(model_onnx)
96
- assert check, 'assert check failed'
97
- onnx.save(model_onnx, f)
98
- except Exception as e:
99
- print(f'{prefix} simplifier failure: {e}')
 
 
 
100
  print(f'{prefix} export success, saved as {f}')
101
  except Exception as e:
102
  print(f'{prefix} export failure: {e}')
103
 
104
- # CoreML export
105
  prefix = colorstr('CoreML:')
106
  try:
107
  import coremltools as ct
 
21
 
22
  if __name__ == '__main__':
23
  parser = argparse.ArgumentParser()
24
+ parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
25
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
26
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
 
27
  parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
28
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
29
+ parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
30
+ parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
31
  opt = parser.parse_args()
32
  opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
33
  print(opt)
 
59
  model.model[-1].export = not opt.grid # set Detect() layer grid export
60
  y = model(img) # dry run
61
 
62
+ # TorchScript export -----------------------------------------------------------------------------------------------
63
  prefix = colorstr('TorchScript:')
64
  try:
65
  print(f'\n{prefix} starting export with torch {torch.__version__}...')
 
70
  except Exception as e:
71
  print(f'{prefix} export failure: {e}')
72
 
73
+ # ONNX export ------------------------------------------------------------------------------------------------------
74
  prefix = colorstr('ONNX:')
75
  try:
76
  import onnx
 
88
  # print(onnx.helper.printable_graph(model_onnx.graph)) # print
89
 
90
  # Simplify
91
+ if opt.simplify:
92
+ try:
93
+ check_requirements(['onnx-simplifier'])
94
+ import onnxsim
95
+
96
+ print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
97
+ model_onnx, check = onnxsim.simplify(model_onnx,
98
+ dynamic_input_shape=opt.dynamic,
99
+ input_shapes={'images': list(img.shape)} if opt.dynamic else None)
100
+ assert check, 'assert check failed'
101
+ onnx.save(model_onnx, f)
102
+ except Exception as e:
103
+ print(f'{prefix} simplifier failure: {e}')
104
  print(f'{prefix} export success, saved as {f}')
105
  except Exception as e:
106
  print(f'{prefix} export failure: {e}')
107
 
108
+ # CoreML export ----------------------------------------------------------------------------------------------------
109
  prefix = colorstr('CoreML:')
110
  try:
111
  import coremltools as ct