"""Gradio demo that classifies hand-drawn digits with a small MNIST CNN."""

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import gradio as gr

# ITU-R BT.601 luma weights (the weights PIL documents for RGB -> "L").
_LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


# Define the CNN
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by two fully connected layers.

    Sized for single-channel 28x28 inputs, shape (N, 1, 28, 28); returns raw
    logits of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Attribute names become the state_dict keys loaded from
        # mnist_cnn.pth below, so they must not be renamed.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits for a batch of (N, 1, 28, 28) images."""
        x = self.pool(self.relu(self.conv1(x)))  # (N, 32, 14, 14)
        x = self.pool(self.relu(self.conv2(x)))  # (N, 64, 7, 7)
        # flatten(1) keeps the batch axis intact; view(-1, 3136) would
        # silently fold spatial data into the batch axis for non-28x28 inputs.
        x = torch.flatten(x, 1)  # (N, 3136)
        x = self.relu(self.fc1(x))  # (N, 128)
        return self.fc2(x)  # (N, 10) logits


# Load the trained model on CPU.
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust
# (consider weights_only=True on torch versions that support it).
model = SimpleCNN()
model.load_state_dict(torch.load('mnist_cnn.pth', map_location=torch.device('cpu')))
model.eval()

# Define the transformation for the input image.
# NOTE(review): the Normalize constants must match those used in training;
# (0.5,), (0.5,) maps [0, 1] to [-1, 1] -- confirm against the training script.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def _to_ink_image(image):
    """Convert a Sketchpad value into a PIL "L" image with light ink on black.

    MNIST digits are light strokes on a dark background, so the drawing is
    normalized to that convention before ``transform`` is applied.

    Accepted inputs:
      * dict -- newer Gradio Sketchpad/ImageEditor value; its "composite"
        entry holds the flattened drawing.
      * numpy array or PIL image -- grayscale (H, W) / (H, W, 1), RGB or RGBA.

    Raises:
        ValueError: if no drawing is present or the shape is unsupported.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("No drawing received.")

    pixels = np.asarray(image, dtype=np.float32)
    if pixels.ndim == 3 and pixels.shape[2] == 1:
        pixels = pixels[..., 0]

    alpha = None
    if pixels.ndim == 2:
        gray = pixels
    elif pixels.ndim == 3 and pixels.shape[2] in (3, 4):
        gray = pixels[..., :3] @ _LUMA_WEIGHTS
        if pixels.shape[2] == 4:
            alpha = pixels[..., 3] / 255.0
    else:
        raise ValueError(f"Unsupported image shape: {pixels.shape}")

    if alpha is not None:
        # Flatten transparency onto a white canvas, so a transparent
        # background does not turn into solid black when alpha is dropped.
        gray = alpha * gray + (1.0 - alpha) * 255.0

    # Invert when the canvas is predominantly light (dark strokes on a light
    # background) to match MNIST's light-on-dark convention.
    if gray.mean() > 127.5:
        gray = 255.0 - gray

    ink = np.clip(gray, 0, 255).astype(np.uint8)
    return transforms.ToPILImage()(ink)  # 2-D uint8 -> mode "L"


# Prediction function
def predict(image):
    """Classify a Sketchpad drawing as a digit.

    Args:
        image: value produced by ``gr.Sketchpad`` (see ``_to_ink_image`` for
            the accepted forms).

    Returns:
        dict mapping each label "0".."9" to its softmax probability, the
        label -> confidence format ``gr.Label`` renders.

    Raises:
        gr.Error: if preprocessing or inference fails, so Gradio shows the
            message instead of receiving a non-numeric "confidence".
    """
    try:
        batch = transform(_to_ink_image(image)).unsqueeze(0)  # add batch dim
        with torch.no_grad():
            probabilities = torch.softmax(model(batch), dim=1)[0]
    except Exception as exc:  # UI boundary: report any failure to the user
        print(f"Error in predict function: {exc}")
        raise gr.Error(f"Prediction failed: {exc}") from exc
    return {str(digit): prob for digit, prob in enumerate(probabilities.tolist())}


# Create the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.Label(),
)

# Launch the interface only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()