# Hugging Face Spaces status: Sleeping
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
# Define the CNN
class SimpleCNN(nn.Module):
    """Small CNN mapping single-channel images to 10 class logits.

    Two 3x3 convolutions (each followed by ReLU and 2x2 max-pooling) feed a
    two-layer fully connected head. ``fc1`` expects ``64 * 7 * 7`` features,
    which is what a 1x28x28 input produces after the two pooling stages.
    """

    # Per-sample feature-map shape (channels, height, width) that ``fc1``
    # was sized for.
    _FEATURE_SHAPE = (64, 7, 7)

    def __init__(self):
        super().__init__()
        # The attribute names below are the keys of the module's state_dict;
        # renaming any of them would break loading of saved weights.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (un-normalised) logits for ``x``.

        Args:
            x: Image tensor with one channel; shape comments below assume a
                1x28x28 image per sample.

        Returns:
            Tensor of shape ``(batch, 10)`` holding one logit per class.

        Raises:
            ValueError: If the pooled feature map is not 64x7x7 per sample.
        """
        x = self.pool(self.relu(self.conv1(x)))  # Output: 32x14x14
        x = self.pool(self.relu(self.conv2(x)))  # Output: 64x7x7

        # ``view(-1, ...)`` would otherwise fold a wrongly sized feature map
        # into the batch dimension and return logits for samples that do not
        # exist, so reject unexpected spatial sizes explicitly.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise ValueError(
                "Expected a 64x7x7 feature map after pooling, got "
                f"{tuple(x.shape[-3:])}; input images should be 1x28x28."
            )

        x = x.view(-1, 64 * 7 * 7)  # Flattened to: 3136
        x = self.relu(self.fc1(x))  # Output: 128
        x = self.fc2(x)  # Output: 10 logits
        return x
# Load the trained model
# NOTE: torch.load unpickles the checkpoint, so only load files you trust.
_checkpoint = torch.load('mnist_cnn.pth', map_location='cpu')
model = SimpleCNN()
model.load_state_dict(_checkpoint)
model.eval()  # inference mode

# Define the transformation for the input image
_preprocessing_steps = [
    # Collapse any colour input to a single channel.
    transforms.Grayscale(num_output_channels=1),
    # Match the 28x28 resolution the network's first linear layer implies.
    transforms.Resize(size=(28, 28)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on the single channel.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify a drawn digit and return per-class probabilities.

    Args:
        image: The value Gradio passes from the sketchpad. Accepted forms are
            a PIL image, a NumPy array (H x W or H x W x C), or an
            editor-style dict whose ``"composite"`` entry holds the drawing.
            ``None`` means nothing was submitted.

    Returns:
        A dict mapping the labels ``"0"`` .. ``"9"`` to softmax
        probabilities, which ``gr.Label`` renders as confidences.

    Raises:
        gr.Error: If no drawing was submitted or preprocessing/inference
            fails. Gradio surfaces this as an error message in the UI; a
            dict with a string value is not a valid ``gr.Label`` output.
    """
    # Editor-style sketchpads wrap the drawing in a dict.
    # NOTE(review): confirm against the installed Gradio version that the
    # composite can be used as-is (background/stroke colours, alpha).
    if isinstance(image, dict):
        image = image.get("composite")

    if image is None:
        raise gr.Error("Please draw a digit before submitting.")

    try:
        # The torchvision pipeline starts with Grayscale, which only accepts
        # PIL images or tensors, so convert raw arrays first.
        if isinstance(image, np.ndarray):
            image = transforms.ToPILImage()(image)

        batch = transform(image).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        return {str(digit): prob for digit, prob in enumerate(probabilities)}
    except Exception as e:
        # Top-level UI boundary: log, then report through Gradio instead of
        # returning a payload the Label component cannot display.
        print(f"Error in predict function: {e}")
        raise gr.Error(f"Prediction failed: {e}") from e
# Create the Gradio interface: a sketchpad as input, a label widget showing
# the per-class confidences returned by ``predict`` as output.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=gr.Label(),
)

# Launch the interface only when run as a script, so importing this module
# (e.g. to reuse ``predict``) does not start a web server as a side effect.
if __name__ == "__main__":
    interface.launch()