# Spaces: Runtime error
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from torch import nn
from torchvision import transforms
# Model definition
class MyMnist_ModelV0(nn.Module):
    """Fully connected classifier: flatten, then three linear layers with ReLU in between.

    Args:
        input_shape: Number of features per sample once the input is flattened.
        hidden_units: Width of the first hidden layer.
        hidden_units2: Width of the second hidden layer.
        output_shape: Number of output logits.
    """

    def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int, output_shape: int):
        super().__init__()
        # The attribute name and layer order determine the state_dict keys
        # (layer_stack.1.*, layer_stack.3.*, layer_stack.5.*), so keep both
        # stable for saved checkpoints to keep loading.
        self.layer_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=input_shape, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units, out_features=hidden_units2),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units2, out_features=output_shape),
        )

    def forward(self, x):
        """Return raw logits (no softmax applied) for a batch.

        ``nn.Flatten()`` keeps dim 0 and flattens the rest, so ``x`` must carry
        a leading batch axis.
        """
        return self.layer_stack(x)
# Load the pre-trained model
# The layer sizes must match the checkpoint exactly: load_state_dict is strict
# by default and rejects missing keys or mismatched tensor shapes.
load_model = MyMnist_ModelV0(
    input_shape=784,  # 28 * 28 pixels once flattened
    hidden_units=256,
    hidden_units2=128,
    output_shape=10,
)
PATH = "state_dict_model.pth"  # Path to the trained model
# map_location remaps stored tensors to the CPU, so a checkpoint saved from a
# GPU still loads on a CPU-only host.
load_model.load_state_dict(torch.load(PATH, map_location=torch.device("cpu")))
# Switch to evaluation mode before serving predictions.
load_model.eval()
# Function to recognize digit
def recognize_digit(image):
    """Classify a drawn digit and return per-class probabilities.

    Args:
        image: A PIL image (anything exposing ``.convert("L")``), a NumPy
            array of shape (H, W) or (H, W, C), or ``None`` when nothing has
            been drawn. Pixel values are assumed to lie in 0..255 -- TODO
            confirm against the Gradio input configuration.

    Returns:
        A dict mapping ``"0"``..``"9"`` to softmax probabilities, or ``""``
        when ``image`` is ``None``.
    """
    if image is None:
        return ""

    if isinstance(image, np.ndarray):
        pixels = image.astype(np.float32)
        if pixels.ndim == 3:
            # Collapse the colour channels (dropping any alpha channel) into a
            # single plane by averaging.
            pixels = pixels[..., :3].mean(axis=-1)
    else:
        # Convert image to grayscale
        pixels = np.array(image.convert("L"), dtype=np.float32)

    # Scale to [0, 1] and add batch and channel axes -> (1, 1, H, W).
    # Without the batch axis nn.Flatten() would treat the rows as samples.
    batch = torch.tensor(pixels / 255.0, dtype=torch.float32)[None, None]

    # Resize image to 28x28, the size the 784-feature input layer expects.
    if tuple(batch.shape[-2:]) != (28, 28):
        batch = nn.functional.interpolate(
            batch, size=(28, 28), mode="bilinear", align_corners=False
        )

    # NOTE(review): no colour inversion is applied. If the model was trained on
    # light-on-dark digits and the sketchpad yields dark-on-light strokes, the
    # input must be inverted somewhere -- confirm against the training data.

    # Perform inference
    with torch.inference_mode():
        logits = load_model(batch)
        probabilities = torch.softmax(logits, dim=1)

    return {str(digit): float(probabilities[0][digit]) for digit in range(10)}
# Function to create a canvas for drawing
def create_canvas():
    """Create a blank Matplotlib figure and return an Agg canvas bound to it.

    The figure has a single axes titled "Draw your digit" with the x and y
    ticks removed.

    Returns:
        A ``FigureCanvasAgg`` instance wrapping the new figure. The Agg canvas
        renders to a raster buffer; it is not an interactive drawing widget.
    """
    fig, ax = plt.subplots()
    ax.set_title("Draw your digit")
    ax.set_xticks([])
    ax.set_yticks([])
    return FigureCanvas(fig)
# Define Gradio interface
# gr.inputs.Image has no `canvas` keyword and cannot embed a Matplotlib canvas;
# the drawable sketchpad is requested with source="canvas" instead.
# shape=(28, 28) makes Gradio resize the drawing to the model's input size and
# type="pil" hands recognize_digit an object that supports .convert("L").
demo = gr.Interface(
    fn=recognize_digit,
    inputs=gr.inputs.Image(shape=(28, 28), source="canvas", type="pil"),
    outputs=gr.outputs.Label(num_top_classes=1),
)

# Launch the interface
# Guarded so that importing this module (e.g. by a reloader or a test) does not
# start a server; running the file as a script behaves as before.
if __name__ == "__main__":
    demo.launch(share=True)