Chaofeng111 commited on
Commit
ef4d538
·
unverified ·
1 Parent(s): 2435bfe

ONNX export in .train() mode fix (#3362)

Browse files
Files changed (1) hide show
  1. models/export.py +2 -0
models/export.py CHANGED
@@ -97,6 +97,8 @@ if __name__ == '__main__':
97
  print(f'{prefix} starting export with onnx {onnx.__version__}...')
98
  f = opt.weights.replace('.pt', '.onnx') # filename
99
  torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
 
 
100
  dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
101
  'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
102
 
 
97
  print(f'{prefix} starting export with onnx {onnx.__version__}...')
98
  f = opt.weights.replace('.pt', '.onnx') # filename
99
  torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
100
+ training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
101
+ do_constant_folding=not opt.train,
102
  dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
103
  'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
104