Sam_S Samridha Shrestha glenn-jocher commited on
Commit
044daaf
·
unverified ·
1 Parent(s): 317f2cc

Add `output_names` argument for ONNX export with dynamic axes (#3456)

Browse files

* Add output names & dynamic axes for onnx export

Add output_names and dynamic_axes names for all outputs in torch.onnx.export. The first four outputs of the model will have names output0, output1, output2, output3

* use first output only + cleanup

Co-authored-by: Samridha Shrestha <[email protected]>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. models/export.py +6 -3
models/export.py CHANGED
@@ -96,11 +96,14 @@ if __name__ == '__main__':
96
 
97
  print(f'{prefix} starting export with onnx {onnx.__version__}...')
98
  f = opt.weights.replace('.pt', '.onnx') # filename
99
- torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
100
  training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
101
  do_constant_folding=not opt.train,
102
- dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
103
- 'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
 
 
 
104
 
105
  # Checks
106
  model_onnx = onnx.load(f) # load onnx model
 
96
 
97
  print(f'{prefix} starting export with onnx {onnx.__version__}...')
98
  f = opt.weights.replace('.pt', '.onnx') # filename
99
+ torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version,
100
  training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
101
  do_constant_folding=not opt.train,
102
+ input_names=['images'],
103
+ output_names=['output'],
104
+ dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
105
+ 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
106
+ } if opt.dynamic else None)
107
 
108
  # Checks
109
  model_onnx = onnx.load(f) # load onnx model