Enable TensorFlow ops for `--nms` and `--agnostic-nms` (#7281)
Browse files* enable TensorFlow ops if flag --nms or --agnostic-nms is used
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update export.py
* Update export.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
export.py
CHANGED
@@ -327,7 +327,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
|
|
327 |
LOGGER.info(f'\n{prefix} export failure: {e}')
|
328 |
|
329 |
|
330 |
-
def export_tflite(keras_model, im, file, int8, data,
|
331 |
# YOLOv5 TensorFlow Lite export
|
332 |
try:
|
333 |
import tensorflow as tf
|
@@ -343,13 +343,15 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
|
|
343 |
if int8:
|
344 |
from models.tf import representative_dataset_gen
|
345 |
dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
|
346 |
-
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
|
347 |
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
348 |
converter.target_spec.supported_types = []
|
349 |
converter.inference_input_type = tf.uint8 # or tf.int8
|
350 |
converter.inference_output_type = tf.uint8 # or tf.int8
|
351 |
converter.experimental_new_quantizer = True
|
352 |
f = str(file).replace('.pt', '-int8.tflite')
|
|
|
|
|
353 |
|
354 |
tflite_model = converter.convert()
|
355 |
open(f, "wb").write(tflite_model)
|
@@ -524,7 +526,7 @@ def run(
|
|
524 |
if pb or tfjs: # pb prerequisite to tfjs
|
525 |
f[6] = export_pb(model, im, file)
|
526 |
if tflite or edgetpu:
|
527 |
-
f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data,
|
528 |
if edgetpu:
|
529 |
f[8] = export_edgetpu(model, im, file)
|
530 |
if tfjs:
|
|
|
327 |
LOGGER.info(f'\n{prefix} export failure: {e}')
|
328 |
|
329 |
|
330 |
+
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
331 |
# YOLOv5 TensorFlow Lite export
|
332 |
try:
|
333 |
import tensorflow as tf
|
|
|
343 |
if int8:
|
344 |
from models.tf import representative_dataset_gen
|
345 |
dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
|
346 |
+
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
|
347 |
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
348 |
converter.target_spec.supported_types = []
|
349 |
converter.inference_input_type = tf.uint8 # or tf.int8
|
350 |
converter.inference_output_type = tf.uint8 # or tf.int8
|
351 |
converter.experimental_new_quantizer = True
|
352 |
f = str(file).replace('.pt', '-int8.tflite')
|
353 |
+
if nms or agnostic_nms:
|
354 |
+
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
355 |
|
356 |
tflite_model = converter.convert()
|
357 |
open(f, "wb").write(tflite_model)
|
|
|
526 |
if pb or tfjs: # pb prerequisite to tfjs
|
527 |
f[6] = export_pb(model, im, file)
|
528 |
if tflite or edgetpu:
|
529 |
+
f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
|
530 |
if edgetpu:
|
531 |
f[8] = export_edgetpu(model, im, file)
|
532 |
if tfjs:
|