Fix FP32 TensorRT model export (#8046)
Browse filesFixed FP32 TRT model export
Co-authored-by: Glenn Jocher <[email protected]>
export.py
CHANGED
@@ -264,8 +264,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
|
|
264 |
for out in outputs:
|
265 |
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
|
266 |
|
267 |
-
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 else 32} engine in {f}')
|
268 |
-
if builder.platform_has_fast_fp16:
|
269 |
config.set_flag(trt.BuilderFlag.FP16)
|
270 |
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
271 |
t.write(engine.serialize())
|
|
|
264 |
for out in outputs:
|
265 |
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
|
266 |
|
267 |
+
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
|
268 |
+
if builder.platform_has_fast_fp16 and half:
|
269 |
config.set_flag(trt.BuilderFlag.FP16)
|
270 |
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
271 |
t.write(engine.serialize())
|