Spaces:
Running
Running
import os | |
import onnxruntime | |
class ONNXEngine:
    """Small convenience wrapper around an ``onnxruntime.InferenceSession``.

    On construction the session is created and the names of all model
    inputs and outputs are cached, so ``run`` only needs the input array.
    """

    def __init__(self, onnx_path, use_gpu):
        """Load the ONNX model and cache its input/output names.

        :param onnx_path: filesystem path of the ONNX model file.
        :param use_gpu: when truthy, request the TensorRT and CUDA execution
            providers ahead of the CPU provider; otherwise CPU only.
        :raises FileNotFoundError: if ``onnx_path`` does not exist.
        """
        if not os.path.exists(onnx_path):
            # FileNotFoundError is a subclass of Exception, so callers that
            # catch the previously raised generic Exception keep working.
            raise FileNotFoundError(f'{onnx_path} is not exists')

        providers = self._select_providers(use_gpu)
        self.onnx_session = onnxruntime.InferenceSession(onnx_path,
                                                         providers=providers)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)

    @staticmethod
    def _select_providers(use_gpu):
        """Return the execution-provider names to request, as a flat list.

        The list must contain plain provider-name strings. Wrapping it in a
        one-element tuple (as a stray trailing comma would) hands the
        session a list where a provider name is expected.

        :param use_gpu: truthy to list the GPU providers before the CPU one.
        :return: list of provider-name strings.
        """
        if use_gpu:
            return [
                'TensorrtExecutionProvider',
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ]
        return ['CPUExecutionProvider']

    def get_output_name(self, onnx_session):
        """Collect the names of every output of the session.

        For a single-output model this is equivalent to
        ``[onnx_session.get_outputs()[0].name]``.

        :param onnx_session: object exposing ``get_outputs()``.
        :return: list of output names, in the order the session reports them.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Collect the names of every input of the session.

        For a single-input model this is equivalent to
        ``[onnx_session.get_inputs()[0].name]``.

        :param onnx_session: object exposing ``get_inputs()``.
        :return: list of input names, in the order the session reports them.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the feed dictionary passed to the session's ``run``.

        Every name in ``input_name`` is mapped to the *same* ``image_numpy``
        object, i.e. ``{name: image_numpy, ...}``.

        :param input_name: iterable of input names.
        :param image_numpy: the input data fed to each named input.
        :return: dict mapping each input name to ``image_numpy``.
        """
        return {name: image_numpy for name in input_name}

    def run(self, image_numpy):
        """Run inference on ``image_numpy`` and return the session's result.

        :param image_numpy: input data; it is fed to every model input.
        :return: whatever ``onnx_session.run`` returns for all cached
            output names.
        """
        # The input data type must match what the model expects.
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        result = self.onnx_session.run(self.output_name, input_feed=input_feed)
        return result