mnist-cnn / model.py
quantumbit's picture
Create model.py
8eb129d verified
import onnxruntime as ort
import numpy as np
class Model:
    """Wrap an ONNX Runtime session that classifies a single 28x28 image.

    The session is created once at construction time and reused by every
    call to :meth:`predict`.
    """

    # Shape handed to the network: the reshape target used for every input.
    # Presumably (batch, height, width, channel) -- confirm against the export.
    _INPUT_SHAPE = (1, 28, 28, 1)

    def __init__(self, model_path="mnist_cnn.onnx"):
        """Load the ONNX model.

        Args:
            model_path: Path of the ONNX file to load. The default is a
                relative path, so it is resolved against the current
                working directory.
        """
        self.session = ort.InferenceSession(model_path)

    def predict(self, inputs):
        """Run the model on one image and report the highest-scoring class.

        Args:
            inputs: Any array-like holding exactly 28 * 28 numeric values
                (flat or nested); it is cast to float32 and reshaped to
                ``(1, 28, 28, 1)``. No scaling or normalisation is applied.

        Returns:
            A dict with two keys:
            ``"digit"`` -- index of the largest value in the first model
            output, as a plain ``int``;
            ``"confidence_scores"`` -- the first model output converted with
            ``tolist()``, keeping whatever nesting the output tensor has
            (presumably ``[[s0, ..., s9]]``; whether these are probabilities
            or raw logits depends on the exported graph -- TODO confirm).

        Raises:
            ValueError: If ``inputs`` does not contain exactly 784 values.
        """
        pixels = np.asarray(inputs).astype(np.float32)

        expected = int(np.prod(self._INPUT_SHAPE))
        if pixels.size != expected:
            # Same exception type NumPy's reshape would raise, but the
            # message tells the caller what the model actually needs.
            raise ValueError(
                f"expected {expected} pixel values (28x28 image), "
                f"got {pixels.size}"
            )
        batch = pixels.reshape(self._INPUT_SHAPE)

        # Ask the session for its input tensor name instead of assuming it is
        # called "input"; the name depends on how the model was exported.
        input_name = self.session.get_inputs()[0].name
        outputs = self.session.run(None, {input_name: batch})

        scores = outputs[0]
        return {
            "digit": int(np.argmax(scores)),
            "confidence_scores": scores.tolist(),
        }