OpenOCR-Demo / tools /infer /onnx_engine.py
topdu's picture
openocr demo
29f689c
raw
history blame
2 kB
import os
import onnxruntime
class ONNXEngine:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The session is created once in ``__init__``. ``run`` feeds the same
    array to every model input and fetches every model output.
    """

    def __init__(self, onnx_path, use_gpu=False):
        """Load an ONNX model and create an inference session.

        :param onnx_path: path to the ``.onnx`` model file.
        :param use_gpu: if true, request the TensorRT and CUDA execution
            providers (in that priority order), with CPU as the fallback;
            otherwise use the CPU provider only.
        :raises FileNotFoundError: if ``onnx_path`` does not exist.
        """
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f'{onnx_path} does not exist')
        if use_gpu:
            # onnxruntime requires a flat sequence of provider names (or
            # (name, options) pairs). Wrapping this list in another tuple,
            # e.g. via a stray trailing comma, makes the session constructor
            # reject the argument.
            providers = [
                'TensorrtExecutionProvider',
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ]
        else:
            providers = ['CPUExecutionProvider']
        self.onnx_session = onnxruntime.InferenceSession(onnx_path,
                                                         providers=providers)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    def get_output_name(self, onnx_session):
        """Return the names of all outputs of ``onnx_session``.

        :param onnx_session: an inference session exposing ``get_outputs()``.
        :return: list of output names, in the session's declared order.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all inputs of ``onnx_session``.

        :param onnx_session: an inference session exposing ``get_inputs()``.
        :return: list of input names, in the session's declared order.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` mapping for ``InferenceSession.run``.

        Every name in ``input_name`` is mapped to the *same* ``image_numpy``
        object. This only makes sense for models whose inputs all take the
        same array (typically single-input models).

        :param input_name: iterable of model input names.
        :param image_numpy: array to feed to each input.
        :return: dict mapping each input name to ``image_numpy``.
        """
        return dict.fromkeys(input_name, image_numpy)

    def run(self, image_numpy):
        """Run inference on ``image_numpy`` and return all model outputs.

        The array's dtype must match the element type the model's inputs
        declare; onnxruntime does not cast it.

        :param image_numpy: array fed to every model input.
        :return: list of outputs, one per name in ``self.output_name``,
            in that order.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)