glenn-jocher commited on
Commit
bfd51f6
·
2 Parent(s): 9a9333d 615d6d0

Merge remote-tracking branch 'origin/master'

Browse files
Files changed (1) hide show
  1. models/{onnx_export.py → export.py} +29 -17
models/{onnx_export.py → export.py} RENAMED
@@ -1,7 +1,7 @@
1
- """Exports a pytorch *.pt model to *.onnx format
2
 
3
  Usage:
4
- $ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
5
  """
6
 
7
  import argparse
@@ -17,27 +17,39 @@ if __name__ == '__main__':
17
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
18
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
19
  opt = parser.parse_args()
20
- opt.img_size *= 2 if len(opt.img_size) == 1 else 1
21
  print(opt)
22
 
23
- # Parameters
24
- f = opt.weights.replace('.pt', '.onnx') # onnx filename
25
  img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
26
 
27
- # Load pytorch model
28
  google_utils.attempt_download(opt.weights)
29
  model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
30
  model.eval()
31
- model.fuse()
32
-
33
- # Export to onnx
34
  model.model[-1].export = True # set Detect() layer export=True
35
  _ = model(img) # dry run
36
- torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
37
- output_names=['output']) # output_names=['classes', 'boxes']
38
-
39
- # Check onnx model
40
- model = onnx.load(f) # load onnx model
41
- onnx.checker.check_model(model) # check onnx model
42
- print(onnx.helper.printable_graph(model.graph)) # print a human readable representation of the graph
43
- print('Export complete. ONNX model saved to %s\nView with https://github.com/lutzroeder/netron' % f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Exports a YOLOv5 *.pt model to *.onnx and *.torchscript formats
2
 
3
  Usage:
4
+ $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
5
  """
6
 
7
  import argparse
 
17
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
18
  parser.add_argument('--batch-size', type=int, default=1, help='batch size')
19
  opt = parser.parse_args()
20
+ opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
21
  print(opt)
22
 
23
+ # Input
 
24
  img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size, (1, 3, 320, 192) iDetection
25
 
26
+ # Load PyTorch model
27
  google_utils.attempt_download(opt.weights)
28
  model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
29
  model.eval()
 
 
 
30
  model.model[-1].export = True # set Detect() layer export=True
31
  _ = model(img) # dry run
32
+
33
+ # Export to torchscript
34
+ try:
35
+ f = opt.weights.replace('.pt', '.torchscript') # filename
36
+ ts = torch.jit.trace(model, img)
37
+ ts.save(f)
38
+ print('Torchscript export success, saved as %s' % f)
39
+ except:
40
+ print('Torchscript export failed.')
41
+
42
+ # Export to ONNX
43
+ try:
44
+ f = opt.weights.replace('.pt', '.onnx') # filename
45
+ model.fuse() # only for ONNX
46
+ torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'],
47
+ output_names=['output']) # output_names=['classes', 'boxes']
48
+
49
+ # Checks
50
+ onnx_model = onnx.load(f) # load onnx model
51
+ onnx.checker.check_model(onnx_model) # check onnx model
52
+ print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable representation of the graph
53
+ print('ONNX export success, saved as %s\nView with https://github.com/lutzroeder/netron' % f)
54
+ except:
55
+ print('ONNX export failed.')