glenn-jocher committed on
Commit
79af114
·
unverified ·
1 Parent(s): 7b1643b

Automatic TFLite uint8 determination (#4515)

Browse files

* Auto TFLite uint8 detection

This PR automatically determines whether a TFLite model is uint8-quantized by inspecting the model's input tensor dtype, rather than requiring the manual `--tfl-int8` argument (which is removed).

The quantization determination is based on @zldrobit's comment: https://github.com/ultralytics/yolov5/pull/1127#issuecomment-901713847

* Cleanup

Files changed (1) hide show
  1. detect.py +5 -6
detect.py CHANGED
@@ -52,7 +52,6 @@ def run(weights='yolov5s.pt', # model.pt path(s)
52
  hide_labels=False, # hide labels
53
  hide_conf=False, # hide confidences
54
  half=False, # use FP16 half-precision inference
55
- tfl_int8=False, # INT8 quantized TFLite model
56
  ):
57
  save_img = not nosave and not source.endswith('.txt') # save inference images
58
  webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
@@ -104,6 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
104
  interpreter.allocate_tensors() # allocate
105
  input_details = interpreter.get_input_details() # inputs
106
  output_details = interpreter.get_output_details() # outputs
 
107
  imgsz = check_img_size(imgsz, s=stride) # check image size
108
 
109
  # Dataloader
@@ -145,15 +145,15 @@ def run(weights='yolov5s.pt', # model.pt path(s)
145
  elif saved_model:
146
  pred = model(imn, training=False).numpy()
147
  elif tflite:
148
- if tfl_int8:
149
  scale, zero_point = input_details[0]['quantization']
150
- imn = (imn / scale + zero_point).astype(np.uint8)
151
  interpreter.set_tensor(input_details[0]['index'], imn)
152
  interpreter.invoke()
153
  pred = interpreter.get_tensor(output_details[0]['index'])
154
- if tfl_int8:
155
  scale, zero_point = output_details[0]['quantization']
156
- pred = (pred.astype(np.float32) - zero_point) * scale
157
  pred[..., 0] *= imgsz[1] # x
158
  pred[..., 1] *= imgsz[0] # y
159
  pred[..., 2] *= imgsz[1] # w
@@ -268,7 +268,6 @@ def parse_opt():
268
  parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
269
  parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
270
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
271
- parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
272
  opt = parser.parse_args()
273
  opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
274
  return opt
 
52
  hide_labels=False, # hide labels
53
  hide_conf=False, # hide confidences
54
  half=False, # use FP16 half-precision inference
 
55
  ):
56
  save_img = not nosave and not source.endswith('.txt') # save inference images
57
  webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
 
103
  interpreter.allocate_tensors() # allocate
104
  input_details = interpreter.get_input_details() # inputs
105
  output_details = interpreter.get_output_details() # outputs
106
+ int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
107
  imgsz = check_img_size(imgsz, s=stride) # check image size
108
 
109
  # Dataloader
 
145
  elif saved_model:
146
  pred = model(imn, training=False).numpy()
147
  elif tflite:
148
+ if int8:
149
  scale, zero_point = input_details[0]['quantization']
150
+ imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
151
  interpreter.set_tensor(input_details[0]['index'], imn)
152
  interpreter.invoke()
153
  pred = interpreter.get_tensor(output_details[0]['index'])
154
+ if int8:
155
  scale, zero_point = output_details[0]['quantization']
156
+ pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
157
  pred[..., 0] *= imgsz[1] # x
158
  pred[..., 1] *= imgsz[0] # y
159
  pred[..., 2] *= imgsz[1] # w
 
268
  parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
269
  parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
270
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
 
271
  opt = parser.parse_args()
272
  opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
273
  return opt