Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,51 +7,49 @@ import matplotlib.pyplot as plt
|
|
7 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
8 |
import PIL.Image
|
9 |
|
10 |
-
# Load the model
|
11 |
-
|
12 |
class MyMnist_ModelV0(nn.Module):
|
13 |
-
|
14 |
super().__init__()
|
15 |
self.layer_stack = nn.Sequential(
|
16 |
-
nn.Flatten(),
|
17 |
-
nn.Linear(in_features=input_shape, out_features=hidden_units),
|
18 |
nn.ReLU(),
|
19 |
nn.Linear(in_features=hidden_units, out_features=hidden_units2),
|
20 |
nn.ReLU(),
|
21 |
nn.Linear(in_features=hidden_units2, out_features=output_shape)
|
22 |
)
|
23 |
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
#
|
28 |
load_model = MyMnist_ModelV0(input_shape=784,
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
)
|
33 |
|
34 |
-
PATH = "state_dict_model.pth"
|
35 |
|
36 |
load_model.load_state_dict(torch.load(PATH))
|
37 |
load_model.eval()
|
|
|
|
|
38 |
def recognize_digit(image):
|
39 |
if image is not None:
|
40 |
-
#
|
41 |
-
image =
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
image = PIL.ImageOps.invert(image) # Invert colors
|
47 |
-
image = transform(image)
|
48 |
-
with torch.inference_mode(): # inference mode of pytoroch
|
49 |
prediction = load_model(image)
|
50 |
prediction = torch.softmax(prediction, dim=1)
|
51 |
return {str(i): float(prediction[0][i]) for i in range(10)}
|
52 |
else:
|
53 |
return ""
|
54 |
|
|
|
55 |
def create_canvas():
|
56 |
fig, ax = plt.subplots()
|
57 |
ax.set_title("Draw your digit")
|
@@ -60,10 +58,13 @@ def create_canvas():
|
|
60 |
canvas = FigureCanvas(fig)
|
61 |
return canvas
|
62 |
|
|
|
63 |
canvas = create_canvas()
|
64 |
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
outputs=gr.Label(num_top_classes=1))
|
69 |
-
demo.launch(share=True)
|
|
|
7 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
8 |
import PIL.Image
|
9 |
|
10 |
+
# Load the model
|
|
|
11 |
class MyMnist_ModelV0(nn.Module):
|
12 |
+
def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
|
13 |
super().__init__()
|
14 |
self.layer_stack = nn.Sequential(
|
15 |
+
nn.Flatten(),
|
16 |
+
nn.Linear(in_features=input_shape, out_features=hidden_units),
|
17 |
nn.ReLU(),
|
18 |
nn.Linear(in_features=hidden_units, out_features=hidden_units2),
|
19 |
nn.ReLU(),
|
20 |
nn.Linear(in_features=hidden_units2, out_features=output_shape)
|
21 |
)
|
22 |
|
23 |
+
def forward(self, x):
|
24 |
+
return self.layer_stack(x)
|
25 |
|
26 |
+
# Load the pre-trained model
|
27 |
load_model = MyMnist_ModelV0(input_shape=784,
|
28 |
+
hidden_units=256,
|
29 |
+
hidden_units2=128,
|
30 |
+
output_shape=10)
|
|
|
31 |
|
32 |
+
PATH = "state_dict_model.pth" # Path to the trained model
|
33 |
|
34 |
load_model.load_state_dict(torch.load(PATH))
|
35 |
load_model.eval()
|
36 |
+
|
37 |
+
# Function to recognize digit
|
38 |
def recognize_digit(image):
|
39 |
if image is not None:
|
40 |
+
# Convert image to grayscale
|
41 |
+
image = np.array(image.convert("L"))
|
42 |
+
# Resize image to 28x28
|
43 |
+
image = torch.tensor(image / 255.0, dtype=torch.float32)
|
44 |
+
# Perform inference
|
45 |
+
with torch.inference_mode():
|
|
|
|
|
|
|
46 |
prediction = load_model(image)
|
47 |
prediction = torch.softmax(prediction, dim=1)
|
48 |
return {str(i): float(prediction[0][i]) for i in range(10)}
|
49 |
else:
|
50 |
return ""
|
51 |
|
52 |
+
# Function to create a canvas for drawing
|
53 |
def create_canvas():
|
54 |
fig, ax = plt.subplots()
|
55 |
ax.set_title("Draw your digit")
|
|
|
58 |
canvas = FigureCanvas(fig)
|
59 |
return canvas
|
60 |
|
61 |
+
# Create canvas
|
62 |
canvas = create_canvas()
|
63 |
|
64 |
+
# Define Gradio interface
|
65 |
+
demo = gr.Interface(fn=recognize_digit,
|
66 |
+
inputs=gr.inputs.Image(canvas=canvas),
|
67 |
+
outputs=gr.outputs.Label(num_top_classes=1))
|
68 |
|
69 |
+
# Launch the interface
|
70 |
+
demo.launch(share=True)
|
|
|
|