import json
import numpy as np
import onnxruntime
from PIL import Image


class BirdApp:
    """Image classifier backed by an ONNX Runtime inference session.

    On construction the ONNX model is loaded into an
    ``onnxruntime.InferenceSession`` and the class-index-to-name mapping is
    read via ``get_img_class_map()``.

    Attributes:
        onnx_session: The ``onnxruntime.InferenceSession`` used for inference.
        img_class_map: Mapping looked up with the stringified class index to
            obtain a class name.
    """

    def __init__(self, model_path="model4app.onnx"):
        """Load the model and the class-name mapping.

        Args:
            model_path: Location of the ONNX model file. Defaults to
                ``"model4app.onnx"``, resolved against the current working
                directory.
        """
        self.onnx_session = onnxruntime.InferenceSession(model_path)
        self.img_class_map = get_img_class_map()
        # The input name is fixed for a given session, so look it up once
        # here instead of on every predict() call.
        self._input_name = self.onnx_session.get_inputs()[0].name

    def predict(self, x):
        """Classify a single image.

        Args:
            x: Image source handed straight to ``transform_image``.

        Returns:
            A dict with ``'class_id'`` (int index of the highest-scoring
            entry of the model's first output) and ``'class_name'`` (the
            matching entry of ``img_class_map``).

        Raises:
            KeyError: If the predicted index has no entry in
                ``img_class_map``.
        """
        input_tensor = transform_image(x)
        outputs = self.onnx_session.run(None, {self._input_name: input_tensor})
        # argmax() without an axis works on the flattened first output; with
        # the single-image batch built by transform_image this is presumably
        # the class index -- confirm against the model's output shape.
        class_id = int(outputs[0].argmax())
        return {'class_id': class_id, 'class_name': self.img_class_map[str(class_id)]}


def transform_image(infile) -> np.ndarray:
    """Load an image and turn it into a float32 batch of one in NCHW layout.

    The image is converted to RGB, resized to 224x224 and returned with its
    raw pixel values cast to ``float32`` (no scaling or normalisation is
    applied here).

    Args:
        infile: Anything ``PIL.Image.open`` accepts (a filename, a path
            object or a binary file object).

    Returns:
        A ``float32`` array of shape ``(1, 3, 224, 224)``.
    """
    # Using the image as a context manager releases the file Pillow opened
    # as soon as the pixel data has been read, rather than waiting for
    # garbage collection.
    with Image.open(infile) as source:
        # Force exactly three channels: a grayscale or palette image would
        # otherwise yield a 2-D array (making the 4-axis transpose below
        # raise) and an RGBA image would yield four channels.
        image = source.convert("RGB").resize((224, 224))
    pixels = np.array(image, dtype=np.float32)  # (height, width, channel)
    # Prepend the batch axis, then reorder NHWC -> NCHW.
    return np.expand_dims(pixels, 0).transpose([0, 3, 1, 2])


def get_img_class_map(path='index_to_name.json'):
    """Load the class-index-to-name mapping from a JSON file.

    Args:
        path: Location of the JSON file. Defaults to
            ``'index_to_name.json'``, resolved against the current working
            directory.

    Returns:
        The parsed JSON content. ``BirdApp.predict`` indexes it with the
        stringified class index, so it is expected to be a JSON object
        keyed by index strings.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; being explicit keeps non-ASCII class names from
    # being decoded with a platform-dependent default encoding.
    with open(path, encoding='utf-8') as f:
        return json.load(f)