jiew commited on
Commit
9f381a3
·
1 Parent(s): fa1134d

Upload 5 files

Browse files
Files changed (5) hide show
  1. 2.png +0 -0
  2. 5.png +0 -0
  3. 7.png +0 -0
  4. Model.py +20 -0
  5. app.py +20 -0
2.png ADDED
5.png ADDED
7.png ADDED
Model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class MNIST(torch.nn.Module):
4
+ def __init__(self):
5
+ super(MNIST, self).__init__()
6
+ self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 32, 3, 1, 1),
7
+ torch.nn.ReLU(),
8
+ torch.nn.Conv2d(32, 64, 3, 1, 1),
9
+ torch.nn.ReLU(),
10
+ torch.nn.MaxPool2d(2, 2))
11
+ self.dense = torch.nn.Sequential(torch.nn.Linear(14 * 14 * 64, 1024),
12
+ torch.nn.ReLU(),
13
+ torch.nn.Dropout(p=0.2),
14
+ torch.nn.Linear(1024, 10))
15
+
16
+ def forward(self, x):
17
+ x = self.conv(x)
18
+ x = x.view(-1, 14 * 14 * 64)
19
+ x = self.dense(x)
20
+ return x
app.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ from torchvision import datasets, transforms
5
+ from Model import MNIST
6
+
7
+ def mnist(image):
8
+ img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
9
+ transf = transforms.ToTensor()
10
+ img_tensor = torch.unsqueeze(transf(img), dim=0)
11
+ res = model(img_tensor)
12
+ res = res.detach().numpy()
13
+ return "the result is: " + str(res.argmax())
14
+
15
+ if __name__ == "__main__":
16
+ device = torch.device('cpu')
17
+ model = MNIST().to(device)
18
+ model.load_state_dict(torch.load('mnist.pkl' , map_location=device))
19
+ myapp = gr.Interface(fn=mnist, inputs=gr.Image(shape=(28,28)), outputs="text",title="手写数字识别", description="请点击上传图片或选择下方样例",examples=['5.png','2.png','7.png'])
20
+ myapp.launch()