import gradio as gr
import cv2
import torch
from torchvision import datasets, transforms
from Model import MNIST


def mnist(image):
    """Classify a handwritten digit image with the globally loaded ``model``.

    Args:
        image: numpy array delivered by the Gradio ``Image`` component.
            Gradio uses RGB (or RGBA) channel order, not OpenCV's BGR order.
            A 2-D array is treated as already grayscale.

    Returns:
        A string of the form ``"the result is: <digit>"``, where ``<digit>``
        is the index of the highest model output.

    Note:
        Relies on the module-level ``model`` created in the ``__main__``
        section below.
    """
    img = image
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            # Gradio images are RGB; COLOR_BGR2GRAY would swap the R/B weights.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        else:
            # Single-channel image stored with an explicit channel axis.
            img = img[:, :, 0]

    # The model is fed 28x28 inputs (see gr.Image(shape=(28, 28)) below);
    # resize explicitly in case the UI component did not already do so.
    if img.shape != (28, 28):
        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

    transf = transforms.ToTensor()
    # ToTensor yields (1, H, W); add a leading batch dimension -> (1, 1, H, W).
    img_tensor = torch.unsqueeze(transf(img), dim=0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        res = model(img_tensor)
    res = res.detach().numpy()
    return "the result is: " + str(res.argmax())


if __name__ == "__main__":
    device = torch.device('cpu')
    model = MNIST().to(device)
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # trusted files here.
    model.load_state_dict(torch.load('mnist.pkl', map_location=device))
    # Switch to inference mode so any dropout / batch-norm layers in the
    # MNIST model behave deterministically at prediction time.
    model.eval()
    myapp = gr.Interface(
        fn=mnist,
        inputs=gr.Image(shape=(28, 28)),
        outputs="text",
        title="手写数字识别",
        description="请点击上传图片或选择下方样例",
        examples=['5.png', '2.png', '7.png'],
    )
    myapp.launch()