glenn-jocher commited on
Commit
b74dd4b
·
unverified ·
1 Parent(s): fcb225c

Add `--int8` argument (#4799)

Browse files

* Add `--int8` argument

* parents[0] bug fix

* Fix order

Files changed (1) hide show
  1. export.py +8 -5
export.py CHANGED
@@ -33,7 +33,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
33
 
34
  FILE = Path(__file__).resolve()
35
  ROOT = FILE.parents[0] # yolov5/ dir
36
- sys.path.append(ROOT.as_posix()) # add yolov5/ to path
 
37
 
38
  from models.common import Conv
39
  from models.experimental import attempt_load
@@ -174,7 +175,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
174
  print(f'\n{prefix} export failure: {e}')
175
 
176
 
177
- def export_tflite(keras_model, im, file, tfl_int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
178
  # YOLOv5 TensorFlow Lite export
179
  try:
180
  import tensorflow as tf
@@ -187,7 +188,7 @@ def export_tflite(keras_model, im, file, tfl_int8, data, ncalib, prefix=colorstr
187
  converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
188
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
189
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
190
- if tfl_int8:
191
  dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
192
  converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
193
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
@@ -234,7 +235,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
234
  inplace=False, # set YOLOv5 Detect() inplace=True
235
  train=False, # model.train() mode
236
  optimize=False, # TorchScript: optimize for mobile
237
- dynamic=False, # ONNX: dynamic axes
 
238
  simplify=False, # ONNX: simplify model
239
  opset=12, # ONNX: opset version
240
  ):
@@ -288,7 +290,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
288
  if pb or tfjs: # pb prerequisite to tfjs
289
  export_pb(model, im, file)
290
  if tflite:
291
- export_tflite(model, im, file, tfl_int8=False, data=data, ncalib=100)
292
  if tfjs:
293
  export_tfjs(model, im, file)
294
 
@@ -309,6 +311,7 @@ def parse_opt():
309
  parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
310
  parser.add_argument('--train', action='store_true', help='model.train() mode')
311
  parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
 
312
  parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
313
  parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
314
  parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
 
33
 
34
  FILE = Path(__file__).resolve()
35
  ROOT = FILE.parents[0] # yolov5/ dir
36
+ if str(ROOT) not in sys.path:
37
+ sys.path.append(str(ROOT)) # add ROOT to PATH
38
 
39
  from models.common import Conv
40
  from models.experimental import attempt_load
 
175
  print(f'\n{prefix} export failure: {e}')
176
 
177
 
178
+ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
179
  # YOLOv5 TensorFlow Lite export
180
  try:
181
  import tensorflow as tf
 
188
  converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
189
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
190
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
191
+ if int8:
192
  dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
193
  converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
194
  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
 
235
  inplace=False, # set YOLOv5 Detect() inplace=True
236
  train=False, # model.train() mode
237
  optimize=False, # TorchScript: optimize for mobile
238
+ int8=False, # CoreML/TF INT8 quantization
239
+ dynamic=False, # ONNX/TF: dynamic axes
240
  simplify=False, # ONNX: simplify model
241
  opset=12, # ONNX: opset version
242
  ):
 
290
  if pb or tfjs: # pb prerequisite to tfjs
291
  export_pb(model, im, file)
292
  if tflite:
293
+ export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
294
  if tfjs:
295
  export_tfjs(model, im, file)
296
 
 
311
  parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
312
  parser.add_argument('--train', action='store_true', help='model.train() mode')
313
  parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
314
+ parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
315
  parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
316
  parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
317
  parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')