TensorFlow.js export enhancements (#4905)
Browse files* Add arguments to TensorFlow NMS call
* Add regex substitution to reorder Identity_*
* Delete reorder in docstring
* Cleanup
* Cleanup2
* Removed `+ \` on string ends (not needed)
Co-authored-by: Glenn Jocher <[email protected]>
- export.py +27 -2
- models/tf.py +1 -1
export.py
CHANGED
@@ -14,7 +14,6 @@ Inference:
|
|
14 |
yolov5s.tflite
|
15 |
|
16 |
TensorFlow.js:
|
17 |
-
$ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order
|
18 |
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
19 |
$ npm install
|
20 |
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
|
@@ -213,16 +212,32 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
|
|
213 |
# YOLOv5 TensorFlow.js export
|
214 |
try:
|
215 |
check_requirements(('tensorflowjs',))
|
|
|
216 |
import tensorflowjs as tfjs
|
217 |
|
218 |
print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
|
219 |
f = str(file).replace('.pt', '_web_model') # js dir
|
220 |
f_pb = file.with_suffix('.pb') # *.pb path
|
|
|
221 |
|
222 |
cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
|
223 |
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
|
224 |
subprocess.run(cmd, shell=True)
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
227 |
except Exception as e:
|
228 |
print(f'\n{prefix} export failure: {e}')
|
@@ -243,6 +258,10 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
|
243 |
dynamic=False, # ONNX/TF: dynamic axes
|
244 |
simplify=False, # ONNX: simplify model
|
245 |
opset=12, # ONNX: opset version
|
|
|
|
|
|
|
|
|
246 |
):
|
247 |
t = time.time()
|
248 |
include = [x.lower() for x in include]
|
@@ -290,7 +309,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
|
|
290 |
if any(tf_exports):
|
291 |
pb, tflite, tfjs = tf_exports[1:]
|
292 |
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
293 |
-
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs
|
|
|
|
|
294 |
if pb or tfjs: # pb prerequisite to tfjs
|
295 |
export_pb(model, im, file)
|
296 |
if tflite:
|
@@ -319,6 +340,10 @@ def parse_opt():
|
|
319 |
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
|
320 |
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
|
321 |
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
|
|
|
|
|
|
|
|
|
322 |
parser.add_argument('--include', nargs='+',
|
323 |
default=['torchscript', 'onnx'],
|
324 |
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
|
|
|
14 |
yolov5s.tflite
|
15 |
|
16 |
TensorFlow.js:
|
|
|
17 |
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
18 |
$ npm install
|
19 |
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
|
|
|
212 |
# YOLOv5 TensorFlow.js export
|
213 |
try:
|
214 |
check_requirements(('tensorflowjs',))
|
215 |
+
import re
|
216 |
import tensorflowjs as tfjs
|
217 |
|
218 |
print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
|
219 |
f = str(file).replace('.pt', '_web_model') # js dir
|
220 |
f_pb = file.with_suffix('.pb') # *.pb path
|
221 |
+
f_json = f + '/model.json' # *.json path
|
222 |
|
223 |
cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \
|
224 |
f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}"
|
225 |
subprocess.run(cmd, shell=True)
|
226 |
|
227 |
+
json = open(f_json).read()
|
228 |
+
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
|
229 |
+
subst = re.sub(
|
230 |
+
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
231 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
232 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
233 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
|
234 |
+
r'{"outputs": {"Identity": {"name": "Identity"}, '
|
235 |
+
r'"Identity_1": {"name": "Identity_1"}, '
|
236 |
+
r'"Identity_2": {"name": "Identity_2"}, '
|
237 |
+
r'"Identity_3": {"name": "Identity_3"}}}',
|
238 |
+
json)
|
239 |
+
j.write(subst)
|
240 |
+
|
241 |
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
242 |
except Exception as e:
|
243 |
print(f'\n{prefix} export failure: {e}')
|
|
|
258 |
dynamic=False, # ONNX/TF: dynamic axes
|
259 |
simplify=False, # ONNX: simplify model
|
260 |
opset=12, # ONNX: opset version
|
261 |
+
topk_per_class=100, # TF.js NMS: topk per class to keep
|
262 |
+
topk_all=100, # TF.js NMS: topk for all classes to keep
|
263 |
+
iou_thres=0.45, # TF.js NMS: IoU threshold
|
264 |
+
conf_thres=0.25 # TF.js NMS: confidence threshold
|
265 |
):
|
266 |
t = time.time()
|
267 |
include = [x.lower() for x in include]
|
|
|
309 |
if any(tf_exports):
|
310 |
pb, tflite, tfjs = tf_exports[1:]
|
311 |
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
312 |
+
model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs,
|
313 |
+
topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres,
|
314 |
+
iou_thres=iou_thres) # keras model
|
315 |
if pb or tfjs: # pb prerequisite to tfjs
|
316 |
export_pb(model, im, file)
|
317 |
if tflite:
|
|
|
340 |
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
|
341 |
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
|
342 |
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
|
343 |
+
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
|
344 |
+
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
|
345 |
+
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
|
346 |
+
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
347 |
parser.add_argument('--include', nargs='+',
|
348 |
default=['torchscript', 'onnx'],
|
349 |
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
|
models/tf.py
CHANGED
@@ -367,7 +367,7 @@ class AgnosticNMS(keras.layers.Layer):
|
|
367 |
# TF Agnostic NMS
|
368 |
def call(self, input, topk_all, iou_thres, conf_thres):
|
369 |
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
|
370 |
-
return tf.map_fn(self._nms, input,
|
371 |
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
|
372 |
name='agnostic_nms')
|
373 |
|
|
|
367 |
# TF Agnostic NMS
|
368 |
def call(self, input, topk_all, iou_thres, conf_thres):
|
369 |
# wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
|
370 |
+
return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input,
|
371 |
fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
|
372 |
name='agnostic_nms')
|
373 |
|