alexluna4 commited on
Commit
32cb560
·
verified ·
1 Parent(s): 72e009a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -7,51 +7,49 @@ import matplotlib.pyplot as plt
7
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
  import PIL.Image
9
 
10
- # Load the model cause it is necessary for pytorch
11
-
12
  class MyMnist_ModelV0(nn.Module):
13
- def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
14
  super().__init__()
15
  self.layer_stack = nn.Sequential(
16
- nn.Flatten(), # neural networks like their inputs in vector form
17
- nn.Linear(in_features=input_shape, out_features=hidden_units), # in_features = number of features in a data sample (784 pixels)
18
  nn.ReLU(),
19
  nn.Linear(in_features=hidden_units, out_features=hidden_units2),
20
  nn.ReLU(),
21
  nn.Linear(in_features=hidden_units2, out_features=output_shape)
22
  )
23
 
24
- def forward(self, x):
25
- return self.layer_stack(x)
26
 
27
- # instance of the model
28
  load_model = MyMnist_ModelV0(input_shape=784,
29
- hidden_units=256,
30
- hidden_units2=128,
31
- output_shape=10
32
- )
33
 
34
- PATH = "state_dict_model.pth" # PATH where you load the model trained
35
 
36
  load_model.load_state_dict(torch.load(PATH))
37
  load_model.eval()
 
 
38
  def recognize_digit(image):
39
  if image is not None:
40
- # Preprocess of the image
41
- image = Image.fromarray(image).resize((28, 28))
42
- transform = transforms.Compose([
43
- transforms.ToTensor(),
44
- transforms.Normalize((0.5,), (0.5,))
45
- ])
46
- image = PIL.ImageOps.invert(image) # Invert colors
47
- image = transform(image)
48
- with torch.inference_mode(): # inference mode of pytoroch
49
  prediction = load_model(image)
50
  prediction = torch.softmax(prediction, dim=1)
51
  return {str(i): float(prediction[0][i]) for i in range(10)}
52
  else:
53
  return ""
54
 
 
55
  def create_canvas():
56
  fig, ax = plt.subplots()
57
  ax.set_title("Draw your digit")
@@ -60,10 +58,13 @@ def create_canvas():
60
  canvas = FigureCanvas(fig)
61
  return canvas
62
 
 
63
  canvas = create_canvas()
64
 
 
 
 
 
65
 
66
- demo = gr.Interface(fn=recognize_digit,
67
- inputs=gr.Image(image_mode="L", source="canvas", canvas=canvas),
68
- outputs=gr.Label(num_top_classes=1))
69
- demo.launch(share=True)
 
7
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
  import PIL.Image
9
 
10
+ # Load the model
 
11
  class MyMnist_ModelV0(nn.Module):
12
+ def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
13
  super().__init__()
14
  self.layer_stack = nn.Sequential(
15
+ nn.Flatten(),
16
+ nn.Linear(in_features=input_shape, out_features=hidden_units),
17
  nn.ReLU(),
18
  nn.Linear(in_features=hidden_units, out_features=hidden_units2),
19
  nn.ReLU(),
20
  nn.Linear(in_features=hidden_units2, out_features=output_shape)
21
  )
22
 
23
+ def forward(self, x):
24
+ return self.layer_stack(x)
25
 
26
+ # Load the pre-trained model
27
  load_model = MyMnist_ModelV0(input_shape=784,
28
+ hidden_units=256,
29
+ hidden_units2=128,
30
+ output_shape=10)
 
31
 
32
+ PATH = "state_dict_model.pth" # Path to the trained model
33
 
34
  load_model.load_state_dict(torch.load(PATH))
35
  load_model.eval()
36
+
37
+ # Function to recognize digit
38
  def recognize_digit(image):
39
  if image is not None:
40
+ # Convert image to grayscale
41
+ image = np.array(image.convert("L"))
42
+ # Resize image to 28x28
43
+ image = torch.tensor(image / 255.0, dtype=torch.float32)
44
+ # Perform inference
45
+ with torch.inference_mode():
 
 
 
46
  prediction = load_model(image)
47
  prediction = torch.softmax(prediction, dim=1)
48
  return {str(i): float(prediction[0][i]) for i in range(10)}
49
  else:
50
  return ""
51
 
52
+ # Function to create a canvas for drawing
53
  def create_canvas():
54
  fig, ax = plt.subplots()
55
  ax.set_title("Draw your digit")
 
58
  canvas = FigureCanvas(fig)
59
  return canvas
60
 
61
+ # Create canvas
62
  canvas = create_canvas()
63
 
64
+ # Define Gradio interface
65
+ demo = gr.Interface(fn=recognize_digit,
66
+ inputs=gr.inputs.Image(canvas=canvas),
67
+ outputs=gr.outputs.Label(num_top_classes=1))
68
 
69
+ # Launch the interface
70
+ demo.launch(share=True)