titipata commited on
Commit
aa08f61
·
1 Parent(s): c3301ff

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +73 -0
  2. requirements.txt +1 -0
  3. thai_digit_net.pth +3 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from pathlib import Path
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ import gradio as gr
9
+
10
+
11
+ transform = transforms.Compose([
12
+ transforms.Resize((28, 28)),
13
+ transforms.Grayscale(),
14
+ transforms.ToTensor()
15
+ ])
16
+ labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
17
+ LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label
18
+
19
+
20
+ # Load model using DropoutThaiDigit instead
21
+ class DropoutThaiDigit(nn.Module):
22
+ def __init__(self):
23
+ super(DropoutThaiDigit, self).__init__()
24
+ self.fc1 = nn.Linear(28 * 28, 392)
25
+ self.fc2 = nn.Linear(392, 196)
26
+ self.fc3 = nn.Linear(196, 98)
27
+ self.fc4 = nn.Linear(98, 10)
28
+ self.dropout = nn.Dropout(0.1)
29
+
30
+ def forward(self, x):
31
+ x = x.view(-1, 28 * 28)
32
+ x = self.fc1(x)
33
+ x = F.relu(x)
34
+ x = self.dropout(x)
35
+ x = self.fc2(x)
36
+ x = F.relu(x)
37
+ x = self.dropout(x)
38
+ x = self.fc3(x)
39
+ x = F.relu(x)
40
+ x = self.dropout(x)
41
+ x = self.fc4(x)
42
+ return x
43
+
44
+
45
+ model = DropoutThaiDigit()
46
+ model.load_state_dict(torch.load("thai_digit_net.pth"))
47
+ model.eval()
48
+
49
+
50
+ def predict(img):
51
+ """
52
+ Predict function takes image and return top 5 predictions
53
+ as a dictionary:
54
+
55
+ {label: confidence, label: confidence, ...}
56
+ """
57
+ if img is None:
58
+ return None
59
+ img = transform(img) # do not need to use 1 - transform(img) because gradio already do it
60
+ probs = model(img).softmax(dim=1).ravel()
61
+ probs, indices = torch.topk(probs, 5) # select top 5
62
+ probs, indices = probs.tolist(), indices.tolist() # transform to list
63
+ confidences = {LABELS[i]: v for i, v in zip(indices, probs)}
64
+ return confidences
65
+
66
+
67
+ gr.Interface(
68
+ fn=predict,
69
+ inputs="sketchpad",
70
+ outputs="label",
71
+ title="Thai Digit Handwritten Classification",
72
+ live=True
73
+ ).launch(enable_queue=True)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch
thai_digit_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b9496e0d1c715adb46e7f30ba3791b95b589c45df168f7566cf3d96d54f3454
3
+ size 1622805