birds-israel / bird_app.py
Tom Beer
change onnx file name
109da02
raw
history blame
906 Bytes
import json
import numpy as np
import onnxruntime
from PIL import Image
class BirdApp:
    """Image classifier backed by an ONNX Runtime inference session.

    Instance state, created once in ``__init__``:
      * ``onnx_session``  -- ``onnxruntime.InferenceSession`` for the model.
      * ``img_class_map`` -- mapping returned by ``get_img_class_map()``;
        it is looked up with the stringified predicted class index.
    """

    def __init__(self, model_path="model4app.onnx"):
        """Load the ONNX model and the class-index-to-name mapping.

        Args:
            model_path: Path of the ONNX model file. Defaults to
                ``"model4app.onnx"``, the file name that used to be
                hard-coded, so ``BirdApp()`` behaves as before.
        """
        self.onnx_session = onnxruntime.InferenceSession(model_path)
        self.img_class_map = get_img_class_map()

    def predict(self, x):
        """Classify a single image.

        Args:
            x: Passed unchanged to ``transform_image``, so it may be
                anything that function accepts.

        Returns:
            A dict with two keys: ``'class_id'`` (the predicted index as a
            plain ``int``) and ``'class_name'`` (the entry of
            ``img_class_map`` stored under ``str(class_id)``).

        Raises:
            KeyError: If the predicted index has no entry in
                ``img_class_map``.
        """
        input_tensor = transform_image(x)
        # The model's first declared input receives the image tensor.
        input_name = self.onnx_session.get_inputs()[0].name
        # Passing None as the output list asks for every model output;
        # only the first one is used for the prediction.
        outputs = self.onnx_session.run(None, {input_name: input_tensor})
        # argmax() without an axis works on the flattened array.
        # NOTE(review): this assumes the first output holds one score per
        # class for a single image -- confirm against the exported model.
        img_label = int(outputs[0].argmax())
        return {'class_id': img_label, 'class_name': self.img_class_map[str(img_label)]}
def transform_image(infile) -> np.ndarray:
    """Load an image and convert it to a ``(1, 3, 224, 224)`` float32 array.

    The image is opened with Pillow, forced to three-channel RGB, resized
    to 224x224 and reordered from height-width-channel to
    channel-height-width with a leading batch axis of size 1.

    Args:
        infile: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object).

    Returns:
        A float32 array of shape ``(1, 3, 224, 224)``.
    """
    image = (
        Image.open(infile)
        # Grayscale, palette and RGBA images would otherwise yield an array
        # without exactly three channels and break the transpose below (or
        # the model input). RGB inputs are unaffected by this conversion.
        .convert("RGB")
        .resize((224, 224))
    )
    # Height x width x channel layout, shape (224, 224, 3).
    pixels = np.array(image, dtype=np.float32)
    # NOTE(review): pixel values stay in the 0..255 range; no scaling or
    # mean/std normalisation is applied here -- confirm this matches the
    # preprocessing the model was trained with.
    return np.expand_dims(pixels, 0).transpose([0, 3, 1, 2])
def get_img_class_map(path='index_to_name.json'):
    """Load the class-index-to-name mapping from a JSON file.

    Args:
        path: Location of the JSON file. Defaults to
            ``'index_to_name.json'`` (resolved against the current working
            directory), which is the file that used to be hard-coded.

    Returns:
        The parsed JSON document, returned as-is.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; being explicit keeps non-ASCII class names intact
    # on platforms whose default text encoding is something else.
    with open(path, encoding='utf-8') as f:
        return json.load(f)