Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files
2.png
ADDED
![]() |
5.png
ADDED
![]() |
7.png
ADDED
![]() |
Model.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class MNIST(torch.nn.Module):
|
4 |
+
def __init__(self):
|
5 |
+
super(MNIST, self).__init__()
|
6 |
+
self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 32, 3, 1, 1),
|
7 |
+
torch.nn.ReLU(),
|
8 |
+
torch.nn.Conv2d(32, 64, 3, 1, 1),
|
9 |
+
torch.nn.ReLU(),
|
10 |
+
torch.nn.MaxPool2d(2, 2))
|
11 |
+
self.dense = torch.nn.Sequential(torch.nn.Linear(14 * 14 * 64, 1024),
|
12 |
+
torch.nn.ReLU(),
|
13 |
+
torch.nn.Dropout(p=0.2),
|
14 |
+
torch.nn.Linear(1024, 10))
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
x = self.conv(x)
|
18 |
+
x = x.view(-1, 14 * 14 * 64)
|
19 |
+
x = self.dense(x)
|
20 |
+
return x
|
app.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from torchvision import datasets, transforms
|
5 |
+
from Model import MNIST
|
6 |
+
|
7 |
+
def mnist(image):
|
8 |
+
img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
9 |
+
transf = transforms.ToTensor()
|
10 |
+
img_tensor = torch.unsqueeze(transf(img), dim=0)
|
11 |
+
res = model(img_tensor)
|
12 |
+
res = res.detach().numpy()
|
13 |
+
return "the result is: " + str(res.argmax())
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
device = torch.device('cpu')
|
17 |
+
model = MNIST().to(device)
|
18 |
+
model.load_state_dict(torch.load('mnist.pkl' , map_location=device))
|
19 |
+
myapp = gr.Interface(fn=mnist, inputs=gr.Image(shape=(28,28)), outputs="text",title="手写数å—识别", description="è¯·ç‚¹å‡»ä¸Šä¼ å›¾ç‰‡æˆ–é€‰æ‹©ä¸‹æ–¹æ ·ä¾‹",examples=['5.png','2.png','7.png'])
|
20 |
+
myapp.launch()
|