DavidD003 commited on
Commit
c0994d8
·
1 Parent(s): 54d2cb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -3,24 +3,35 @@ from fastai.vision.all import *
3
  from PIL import Image
4
  #
5
  #learn = load_learner('export.pkl')
6
- learn = torch.load('digit_classifier.pth')
7
- learn.eval() #switch to eval mode
 
 
 
 
 
 
 
 
 
 
 
8
  labels = [str(x) for x in range(10)]
9
- #################################
10
- #Define class for importing Model
11
- class DigitClassifier(torch.nn.Module):
12
- def __init__(self):
13
- super().__init__()
14
- self.fc1 = torch.nn.Linear(64, 32)
15
- self.fc2 = torch.nn.Linear(32, 16)
16
- self.fc3 = torch.nn.Linear(16, 10)
17
 
18
- def forward(self, x):
19
- x = x.view(-1, 64)
20
- x = torch.relu(self.fc1(x))
21
- x = torch.relu(self.fc2(x))
22
- x = self.fc3(x)
23
- return x
24
  #########################################
25
  #Define function to reduce image of arbitrary size to 8x8 per model requirements.
26
  def reduce_image_count(image):
@@ -45,7 +56,7 @@ def predict(img):
45
  pic = np.array(gray_img) #convert to array
46
  inp_img=reduce_image_count(pic)#Reduce image to required input size
47
 
48
- otpt=F.softmax(learn.forward(inp_img.view(-1,64)))
49
  #pred,pred_idx,probs = learn.predict(img)
50
 
51
  return dict([[labels[i], float(otpt[0].data[i])] for i in range(len(labels))]),inp_img
 
3
  from PIL import Image
4
  #
5
  #learn = load_learner('export.pkl')
6
+ #learn = torch.load('digit_classifier.pth')
7
+ #learn.eval() #switch to eval mode
8
+ model_dict=torch.load('my_model.pt')
9
+ W1,B1,W2,B2,W3,B3=model_dict['W1'],model_dict['B1'],model_dict['W2'],model_dict['B2'],model_dict['W3'],model_dict['B3']
10
+ def mdlV2(xb):
11
+ res = xb@W1+B1
12
+ res = res.max(tensor(0.))
13
+ res = res@W2+B2 # returns 10 features for each input
14
+ res = res.max(tensor(0.))
15
+ res = res@W3+B3 # returns 10 features for each input
16
+ return res
17
+
18
+
19
  labels = [str(x) for x in range(10)]
20
+ # #################################
21
+ # #Define class for importing Model
22
+ # class DigitClassifier(torch.nn.Module):
23
+ # def __init__(self):
24
+ # super().__init__()
25
+ # self.fc1 = torch.nn.Linear(64, 32)
26
+ # self.fc2 = torch.nn.Linear(32, 16)
27
+ # self.fc3 = torch.nn.Linear(16, 10)
28
 
29
+ # def forward(self, x):
30
+ # x = x.view(-1, 64)
31
+ # x = torch.relu(self.fc1(x))
32
+ # x = torch.relu(self.fc2(x))
33
+ # x = self.fc3(x)
34
+ # return x
35
  #########################################
36
  #Define function to reduce image of arbitrary size to 8x8 per model requirements.
37
  def reduce_image_count(image):
 
56
  pic = np.array(gray_img) #convert to array
57
  inp_img=reduce_image_count(pic)#Reduce image to required input size
58
 
59
+ otpt=F.softmax(mdlV2(inp_img.view(-1,64)))
60
  #pred,pred_idx,probs = learn.predict(img)
61
 
62
  return dict([[labels[i], float(otpt[0].data[i])] for i in range(len(labels))]),inp_img