DavidD003 commited on
Commit
54d2cb7
·
1 Parent(s): 9f90e61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -6,7 +6,22 @@ from PIL import Image
6
  learn = torch.load('digit_classifier.pth')
7
  learn.eval() #switch to eval mode
8
  labels = [str(x) for x in range(10)]
 
 
 
 
 
 
 
 
9
 
 
 
 
 
 
 
 
10
  #Define function to reduce image of arbitrary size to 8x8 per model requirements.
11
  def reduce_image_count(image):
12
  output_size = (8, 8)
@@ -18,9 +33,11 @@ def reduce_image_count(image):
18
  block = image[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]]
19
  count = np.count_nonzero(block)
20
  output[i, j] = 16 - ((count / (block_size[0] * block_size[1])) * 16)
21
-
22
  return output
23
 
 
 
24
  def predict(img):
25
  #First take input and reduce it to 8x8 px as the dataset was
26
  pil_image = Image.open(img) #get image
 
6
  learn = torch.load('digit_classifier.pth')
7
  learn.eval() #switch to eval mode
8
  labels = [str(x) for x in range(10)]
9
+ #################################
10
+ #Define class for importing Model
11
+ class DigitClassifier(torch.nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.fc1 = torch.nn.Linear(64, 32)
15
+ self.fc2 = torch.nn.Linear(32, 16)
16
+ self.fc3 = torch.nn.Linear(16, 10)
17
 
18
+ def forward(self, x):
19
+ x = x.view(-1, 64)
20
+ x = torch.relu(self.fc1(x))
21
+ x = torch.relu(self.fc2(x))
22
+ x = self.fc3(x)
23
+ return x
24
+ #########################################
25
  #Define function to reduce image of arbitrary size to 8x8 per model requirements.
26
  def reduce_image_count(image):
27
  output_size = (8, 8)
 
33
  block = image[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]]
34
  count = np.count_nonzero(block)
35
  output[i, j] = 16 - ((count / (block_size[0] * block_size[1])) * 16)
36
+
37
  return output
38
 
39
+ #########################################
40
+
41
  def predict(img):
42
  #First take input and reduce it to 8x8 px as the dataset was
43
  pil_image = Image.open(img) #get image