alexluna4 commited on
Commit
dcbb1ee
·
verified ·
1 Parent(s): 93baa5a

Upload inferences.py

Browse files
Files changed (1) hide show
  1. inferences.py +53 -0
inferences.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import transforms
5
+
6
+ # Load the model cause it is necessary for pytorch
7
+
8
+ class MyMnist_ModelV0(nn.Module):
9
+ def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
10
+ super().__init__()
11
+ self.layer_stack = nn.Sequential(
12
+ nn.Flatten(), # neural networks like their inputs in vector form
13
+ nn.Linear(in_features=input_shape, out_features=hidden_units), # in_features = number of features in a data sample (784 pixels)
14
+ nn.ReLU(),
15
+ nn.Linear(in_features=hidden_units, out_features=hidden_units2),
16
+ nn.ReLU(),
17
+ nn.Linear(in_features=hidden_units2, out_features=output_shape)
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.layer_stack(x)
22
+
23
+ # instance of the model
24
+ load_model = MyMnist_ModelV0(input_shape=784,
25
+ hidden_units=256,
26
+ hidden_units2=128,
27
+ output_shape=10
28
+ )
29
+
30
+ PATH = "C:/Users/Hp Pavilion/Downloads/state_dict_model.pth" # PATH where you load the model trained
31
+
32
+ load_model.load_state_dict(torch.load(PATH))
33
+ load_model.eval()
34
+ def recognize_digit(image):
35
+ if image is not None:
36
+ # Preprocess of the image
37
+ transform = transforms.Compose([
38
+ transforms.ToTensor(),
39
+ transforms.Normalize((0.5,), (0.5,))
40
+ ])
41
+ image = transform(image)
42
+ with torch.inference_mode(): # inference mode of pytoroch
43
+ prediction = load_model(image)
44
+ prediction = torch.softmax(prediction, dim=1)
45
+ return {str(i): float(prediction[0][i]) for i in range(10)}
46
+ else:
47
+ return ""
48
+
49
+
50
+ demo = gr.Interface(fn=recognize_digit,
51
+ inputs=gr.Image(shape=(28,28), image_mode="L", invert_colors=True, source="canvas"),
52
+ outputs=gr.Label(num_top_classes=1))
53
+ demo.launch(True)