Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,22 @@ from PIL import Image
|
|
6 |
learn = torch.load('digit_classifier.pth')
|
7 |
learn.eval() #switch to eval mode
|
8 |
labels = [str(x) for x in range(10)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
#Define function to reduce image of arbitrary size to 8x8 per model requirements.
|
11 |
def reduce_image_count(image):
|
12 |
output_size = (8, 8)
|
@@ -18,9 +33,11 @@ def reduce_image_count(image):
|
|
18 |
block = image[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]]
|
19 |
count = np.count_nonzero(block)
|
20 |
output[i, j] = 16 - ((count / (block_size[0] * block_size[1])) * 16)
|
21 |
-
|
22 |
return output
|
23 |
|
|
|
|
|
24 |
def predict(img):
|
25 |
#First take input and reduce it to 8x8 px as the dataset was
|
26 |
pil_image = Image.open(img) #get image
|
|
|
6 |
learn = torch.load('digit_classifier.pth')
|
7 |
learn.eval() #switch to eval mode
|
8 |
labels = [str(x) for x in range(10)]
|
9 |
+
#################################
|
10 |
+
#Define class for importing Model
|
11 |
+
class DigitClassifier(torch.nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.fc1 = torch.nn.Linear(64, 32)
|
15 |
+
self.fc2 = torch.nn.Linear(32, 16)
|
16 |
+
self.fc3 = torch.nn.Linear(16, 10)
|
17 |
|
18 |
+
def forward(self, x):
|
19 |
+
x = x.view(-1, 64)
|
20 |
+
x = torch.relu(self.fc1(x))
|
21 |
+
x = torch.relu(self.fc2(x))
|
22 |
+
x = self.fc3(x)
|
23 |
+
return x
|
24 |
+
#########################################
|
25 |
#Define function to reduce image of arbitrary size to 8x8 per model requirements.
|
26 |
def reduce_image_count(image):
|
27 |
output_size = (8, 8)
|
|
|
33 |
block = image[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]]
|
34 |
count = np.count_nonzero(block)
|
35 |
output[i, j] = 16 - ((count / (block_size[0] * block_size[1])) * 16)
|
36 |
+
|
37 |
return output
|
38 |
|
39 |
+
#########################################
|
40 |
+
|
41 |
def predict(img):
|
42 |
#First take input and reduce it to 8x8 px as the dataset was
|
43 |
pil_image = Image.open(img) #get image
|