File size: 732 Bytes
9f381a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import gradio as gr
import cv2
import torch
from torchvision import datasets, transforms
from Model import MNIST

def mnist(image):
	"""Classify a handwritten digit image with the module-level ``model``.

	Args:
		image: numpy array delivered by the Gradio ``Image`` input (RGB
			channel order). 2-D grayscale, single-channel and 4-channel
			(RGBA) arrays are also accepted. ``None`` when nothing was
			uploaded.

	Returns:
		A human-readable string containing the predicted class index, or a
		prompt asking for an image when ``image`` is ``None``.

	NOTE(review): relies on the global ``model`` created in the
	``__main__`` section; calling this after a plain ``import`` raises
	NameError.
	"""
	if image is None:
		# Gradio passes None when the user submits without an image.
		return "please upload an image"
	# Gradio delivers RGB arrays (not OpenCV's BGR), so use RGB conversion
	# codes; BGR2GRAY would weight the red and blue channels the wrong way.
	if image.ndim == 2:
		img = image
	elif image.shape[2] == 1:
		img = image[:, :, 0]
	elif image.shape[2] == 4:
		img = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
	else:
		img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
	# ToTensor turns an HxW uint8 array into a 1xHxW float tensor scaled to
	# [0, 1]; unsqueeze prepends the batch dimension -> 1x1xHxW.
	img_tensor = torch.unsqueeze(transforms.ToTensor()(img), dim=0)
	# Pure inference: no autograd graph is needed.
	with torch.no_grad():
		res = model(img_tensor)
	# With a batch of one, argmax over the flattened output is the class index.
	return "the result is: " + str(res.argmax().item())

if __name__ == "__main__":
	# Inference runs on CPU; map_location lets a checkpoint saved on another
	# device (e.g. GPU) be loaded here.
	device = torch.device('cpu')
	# `model` is intentionally a module-level global: mnist() reads it.
	model = MNIST().to(device)
	# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
	model.load_state_dict(torch.load('mnist.pkl', map_location=device))
	# Put dropout/batch-norm layers (if MNIST has any) into inference mode.
	model.eval()
	myapp = gr.Interface(
		fn=mnist,
		# shape=(28, 28) asks Gradio (3.x API) to crop/resize uploads to 28x28.
		inputs=gr.Image(shape=(28, 28)),
		outputs="text",
		title="手写数字识别",  # "Handwritten digit recognition"
		description="请点击上传图片或选择下方样例",  # "Click to upload an image or pick an example below"
		examples=['5.png', '2.png', '7.png'],
	)
	myapp.launch()