"""Gradio demo that classifies a hand-drawn digit with a small MNIST MLP."""

import gradio as gr
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import PIL.Image


class MyMnist_ModelV0(nn.Module):
    """Fully connected classifier: Flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_shape: Number of features after flattening (784 for a 28x28 image).
        hidden_units: Width of the first hidden layer.
        hidden_units2: Width of the second hidden layer.
        output_shape: Number of output logits (one per class).
    """

    def __init__(self, input_shape: int, hidden_units: int, hidden_units2: int,
                 output_shape: int):
        super().__init__()
        self.layer_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=input_shape, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units, out_features=hidden_units2),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units2, out_features=output_shape),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits.

        ``nn.Flatten()`` flattens from dim 1 and keeps dim 0, so *x* must carry a
        leading batch dimension.
        """
        return self.layer_stack(x)


# Load the pre-trained weights once, at import time.
load_model = MyMnist_ModelV0(input_shape=784, hidden_units=256, hidden_units2=128, output_shape=10)
PATH = "state_dict_model.pth"  # Path to the trained state dict
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
load_model.load_state_dict(torch.load(PATH, map_location="cpu"))
load_model.eval()

# Grayscale -> 28x28 -> float tensor of shape (1, 28, 28) scaled to [0, 1].
# The [0, 1] scaling is the same as the previous ``pixels / 255.0``; 28 * 28 = 784
# matches the model's input_shape.
_PREPROCESS = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]
)


def _to_pil(image):
    """Coerce a Gradio image payload (PIL image, ndarray or editor dict) to PIL, or None."""
    if isinstance(image, dict):
        # Editor-style components return a dict rather than a bare image; newer
        # Gradio Sketchpads use "composite", older sketch tools used "image".
        # NOTE(review): confirm the key against the installed Gradio version.
        image = image.get("composite", image.get("image"))
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        image = PIL.Image.fromarray(image.astype(np.uint8, copy=False))
    return image


def recognize_digit(image):
    """Classify a hand-drawn digit.

    Args:
        image: The drawing from the Gradio input component -- a PIL image, a
            NumPy array, or an editor dict wrapping one. ``None`` when empty.

    Returns:
        A dict mapping each class label ("0".."9") to its softmax probability,
        or ``""`` when no image was supplied.
    """
    pil_image = _to_pil(image)
    if pil_image is None:
        return ""

    # NOTE(review): MNIST digits are light strokes on a dark background. Gradio 3's
    # Sketchpad inverts colors by default; confirm the drawing convention on the
    # installed Gradio version and invert here if predictions look wrong.
    # unsqueeze(0) adds the batch dim explicitly: (1, 28, 28) -> (1, 1, 28, 28),
    # which nn.Flatten turns into the (1, 784) row the first Linear layer expects.
    batch = _PREPROCESS(pil_image).unsqueeze(0)

    with torch.inference_mode():
        logits = load_model(batch)
    probabilities = torch.softmax(logits, dim=1)[0]
    return {str(digit): float(p) for digit, p in enumerate(probabilities.tolist())}


def create_canvas():
    """Return an Agg canvas around an empty, axis-less figure titled "Draw your digit".

    The Gradio interface below does not use this (Gradio supplies its own drawing
    component); it is kept so existing callers of the helper keep working. The
    figure is not closed here -- call ``plt.close(canvas.figure)`` when finished.
    """
    fig, ax = plt.subplots()
    ax.set_title("Draw your digit")
    ax.set_xticks([])
    ax.set_yticks([])
    canvas = FigureCanvas(fig)
    return canvas


# Gradio interface: a sketchpad input delivered as a PIL image, and a label output
# showing the top prediction.
demo = gr.Interface(
    fn=recognize_digit,
    inputs=gr.Sketchpad(type="pil"),
    outputs=gr.Label(num_top_classes=1),
)


if __name__ == "__main__":
    demo.launch(share=True)