# mnist-test / app.py
# Uploaded by jiew ("Upload 5 files", commit 9f381a3; 732 bytes)
import gradio as gr
import cv2
import torch
from torchvision import datasets, transforms
from Model import MNIST
def mnist(image):
    """Classify a handwritten digit image with the loaded MNIST model.

    Relies on the module-level ``model`` global, which is only assigned in
    the ``__main__`` block; calling this function after a plain import
    (without running the script) raises ``NameError``.

    Args:
        image: The numpy array produced by the Gradio ``gr.Image`` input, or
            ``None`` when the user submits without an image. Colour images
            arrive in RGB (or RGBA) channel order, not OpenCV's BGR order.
            Grayscale images may arrive as (H, W) or (H, W, 1).

    Returns:
        ``"the result is: <digit>"``, where ``<digit>`` is the index of the
        highest model score, or ``"no image provided"`` for a ``None`` input.
    """
    if image is None:
        return "no image provided"

    # Reduce to a single 2-D gray channel. Gradio supplies RGB(A), so the
    # RGB*/RGBA* codes are the correct ones; gray input needs no conversion
    # (and would make cvtColor raise).
    if image.ndim == 3 and image.shape[2] == 1:
        img = image[:, :, 0]
    elif image.ndim == 3 and image.shape[2] == 4:
        img = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    elif image.ndim == 3:
        img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        img = image

    # The interface asks Gradio for 28x28 input (gr.Image(shape=(28,28)));
    # resize defensively in case a different size gets through.
    # NOTE(review): presumably Model.MNIST requires exactly 28x28 — confirm.
    if img.shape != (28, 28):
        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

    # ToTensor turns an (H, W) array into (1, H, W) (scaling uint8 to [0, 1]);
    # unsqueeze adds the batch dimension -> (1, 1, H, W).
    img_tensor = transforms.ToTensor()(img).unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        scores = model(img_tensor)

    return "the result is: " + str(scores.argmax().item())
if __name__ == "__main__":
    # Build the network and load the trained weights onto the CPU.
    # `model` stays a module-level global because `mnist` reads it.
    device = torch.device('cpu')
    model = MNIST().to(device)
    # torch.load unpickles the file, so only load a trusted checkpoint.
    model.load_state_dict(torch.load('mnist.pkl', map_location=device))
    # Put the network in inference mode. The default is training mode, in which
    # any dropout / batch-norm layers in MNIST would make predictions noisy.
    model.eval()

    myapp = gr.Interface(
        fn=mnist,
        # NOTE(review): `shape=` is a Gradio 3.x argument; it was removed in
        # Gradio 4 — confirm the pinned Gradio version.
        inputs=gr.Image(shape=(28,28)),
        outputs="text",
        title="手写数字识别",  # "Handwritten digit recognition"
        description="请点击上传图片或选择下方样例",  # "Click to upload an image or pick an example below"
        examples=['5.png','2.png','7.png'],
    )
    myapp.launch()