alexluna4's picture
Update app.py
32cb560 verified
import gradio as gr
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import PIL.Image
# Model definition: a small fully connected (MLP) digit classifier
class MyMnist_ModelV0(nn.Module):
    """Fully connected classifier: Flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    All layers live in one ``nn.Sequential`` called ``layer_stack``, which gives
    state-dict keys ``layer_stack.{1,3,5}.{weight,bias}``. Keep this layout
    unchanged, or previously saved checkpoints will no longer load.

    Args:
        input_shape: Number of features per sample after flattening.
        hidden_units: Width of the first hidden layer.
        hidden_units2: Width of the second hidden layer.
        output_shape: Number of output logits (classes).
    """

    def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
        super().__init__()
        widths = (input_shape, hidden_units, hidden_units2, output_shape)
        modules = [nn.Flatten()]
        # Pair consecutive widths into Linear layers and put a ReLU between
        # each pair. There is no activation before the first Linear or after
        # the last one.
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth:
                modules.append(nn.ReLU())
            modules.append(nn.Linear(in_features=fan_in, out_features=fan_out))
        self.layer_stack = nn.Sequential(*modules)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the input batch ``x``."""
        logits = self.layer_stack(x)
        return logits
# Load the pre-trained model
load_model = MyMnist_ModelV0(input_shape=784,
                             hidden_units=256,
                             hidden_units2=128,
                             output_shape=10)
PATH = "state_dict_model.pth"  # Path to the trained model's state_dict
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host. Nothing here moves the model to a GPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# torch >= 1.13 also accepts weights_only=True to restrict what is unpickled.
load_model.load_state_dict(torch.load(PATH, map_location="cpu"))
load_model.eval()  # Switch to evaluation mode before serving predictions
# Function to recognize digit
def recognize_digit(image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        image: The drawing from the Gradio input, either a ``PIL.Image.Image``
            or a NumPy image array. ``None`` means nothing was submitted.

    Returns:
        A dict mapping each label ``"0"``..``"9"`` to its softmax probability,
        or ``""`` when ``image`` is ``None``.
    """
    if image is None:
        return ""
    # Gradio delivers a NumPy array unless the component uses type="pil".
    # Normalise to PIL so the grayscale and resize steps below work for both.
    if isinstance(image, np.ndarray):
        image = PIL.Image.fromarray(image)
    # The model's first Linear layer expects 784 = 28 * 28 features, so the
    # drawing must be single-channel and exactly 28x28 pixels.
    image = image.convert("L").resize((28, 28))
    # NOTE(review): MNIST digits are light strokes on a dark background. If the
    # canvas produces dark strokes on a light background, the pixels may need
    # inverting (255 - x) to match training. Confirm against the training
    # preprocessing.
    pixels = np.asarray(image, dtype=np.float32) / 255.0
    # Add a leading batch dimension. nn.Flatten (start_dim=1) treats dim 0 as
    # the batch, so a bare (28, 28) tensor would reach the Linear layer with
    # only 28 features.
    batch = torch.from_numpy(pixels).unsqueeze(0)
    with torch.inference_mode():
        logits = load_model(batch)
    probabilities = torch.softmax(logits, dim=1)[0]
    return {str(digit): float(p) for digit, p in enumerate(probabilities.tolist())}
# Function to create a canvas for drawing
def create_canvas():
    """Build a blank matplotlib canvas titled "Draw your digit" with no axis ticks.

    Returns:
        An Agg ``FigureCanvas`` wrapping a figure that has a single axes.
    """
    # Build the Figure directly instead of going through pyplot. pyplot keeps
    # every figure alive in a global manager until plt.close(), so repeated
    # calls would leak figures in a long-running server process.
    from matplotlib.figure import Figure

    fig = Figure()
    canvas = FigureCanvas(fig)
    ax = fig.add_subplot(111)
    ax.set_title("Draw your digit")
    ax.set_xticks([])
    ax.set_yticks([])
    # NOTE(review): no Gradio component visible here accepts a matplotlib
    # canvas, so this object may be unused by the UI. Confirm before removing.
    return canvas
# Create canvas
# NOTE(review): the Gradio interface below does not consume this matplotlib
# canvas. The module-level name is kept in case other code uses it. Confirm
# whether it can be removed.
canvas = create_canvas()
# Define the Gradio interface. source="canvas" gives the user a drawing pad,
# and type="pil" delivers a PIL image, which recognize_digit's .convert("L")
# expects. Gradio's Image input has no `canvas` keyword, so the drawing pad
# must be requested through `source` instead.
demo = gr.Interface(fn=recognize_digit,
                    inputs=gr.inputs.Image(source="canvas", type="pil"),
                    outputs=gr.outputs.Label(num_top_classes=1))
# Launch the interface
demo.launch(share=True)