glenn-jocher commited on
Commit
e97d129
·
unverified ·
1 Parent(s): f2de1ad

Update export.py with --train mode argument (#3066)

Browse files
Files changed (1) hide show
  1. models/export.py +3 -0
models/export.py CHANGED
@@ -29,6 +29,7 @@ if __name__ == '__main__':
29
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
30
  parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
31
  parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
 
32
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
33
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
34
  opt = parser.parse_args()
@@ -53,6 +54,8 @@ if __name__ == '__main__':
53
  # Update model
54
  if opt.half:
55
  img, model = img.half(), model.half() # to FP16
 
 
56
  for k, m in model.named_modules():
57
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
58
  if isinstance(m, models.common.Conv): # assign export-friendly activations
 
29
  parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
30
  parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
31
  parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
32
+ parser.add_argument('--train', action='store_true', help='model.train() mode')
33
  parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
34
  parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
35
  opt = parser.parse_args()
 
54
  # Update model
55
  if opt.half:
56
  img, model = img.half(), model.half() # to FP16
57
+ if opt.train:
58
+ model.train() # training mode (no grid construction in Detect layer)
59
  for k, m in model.named_modules():
60
  m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
61
  if isinstance(m, models.common.Conv): # assign export-friendly activations