# Spaces:
# Runtime error
# Runtime error
import json | |
import numpy as np | |
import onnxruntime | |
from PIL import Image | |
class BirdApp:
    """Image classifier backed by an ONNX Runtime inference session.

    The inference session and the class-index-to-name mapping are both loaded
    once, at construction time, and reused for every call to ``predict``.
    """

    def __init__(self, model_path: str = "model4app.onnx"):
        """Load the ONNX model and the class-name lookup table.

        Args:
            model_path: Path of the ONNX model file. Defaults to
                ``model4app.onnx``, the file name that used to be hard-coded.
        """
        self.onnx_session = onnxruntime.InferenceSession(model_path)
        self.img_class_map = get_img_class_map()

    def predict(self, x):
        """Run the model on one image and return the top class.

        Args:
            x: Image source, forwarded unchanged to ``transform_image``.

        Returns:
            A dict with ``class_id`` (plain ``int``) and ``class_name`` (the
            value stored under ``str(class_id)`` in the class map).

        Raises:
            KeyError: If the predicted index has no entry in the class map.
        """
        input_tensor = transform_image(x)
        input_name = self.onnx_session.get_inputs()[0].name
        outputs = self.onnx_session.run(None, {input_name: input_tensor})
        # ``argmax()`` without an axis indexes into the flattened first output;
        # convert the NumPy scalar to a built-in int once and reuse it for
        # both the returned id and the string key of the class map.
        class_id = int(outputs[0].argmax())
        return {'class_id': class_id, 'class_name': self.img_class_map[str(class_id)]}
def transform_image(infile) -> np.ndarray:
    """Load an image and turn it into a single-item float32 batch.

    Args:
        infile: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object).

    Returns:
        A ``float32`` array of shape ``(1, 3, 224, 224)``: the image resized
        to 224x224, channels moved ahead of the spatial axes, with a leading
        batch axis. Pixel values are left unscaled.
    """
    # The context manager releases the file handle Pillow opened for a path.
    with Image.open(infile) as source:
        # Force three channels: grayscale, palette and RGBA inputs would
        # otherwise produce a 2-D or 4-channel array and break the transpose
        # below (or feed the model the wrong channel count).
        image = source.convert("RGB").resize((224, 224))
    pixels = np.array(image, dtype=np.float32)  # (height, width, channel)
    return np.expand_dims(pixels, 0).transpose([0, 3, 1, 2])
def get_img_class_map():
    """Load the class-index-to-name mapping from ``index_to_name.json``.

    The file is read from the current working directory and decoded as UTF-8
    (the standard JSON encoding), so non-ASCII class names load the same way
    regardless of the platform's default locale encoding.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If ``index_to_name.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open('index_to_name.json', encoding='utf-8') as f:
        return json.load(f)