|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
class ConvNet(nn.Module):
    """Small CNN: three 3x3 conv layers with max-pooling, then two linear layers.

    The ``64 * 4 * 4`` input width of ``fc1`` only fits 32x32 inputs with 3
    channels: 32 -> conv 30 -> pool 15 -> conv 13 -> pool 6 -> conv 4.
    The forward pass returns raw logits of width 10 (no softmax).

    Attribute names (conv1, conv2, conv3, fc1, fc2) are the parameter keys of
    the saved checkpoint, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores."""
        features = self.pool(F.relu(self.conv1(x)))         # 32 -> 30 -> 15
        features = self.pool(F.relu(self.conv2(features)))  # 15 -> 13 -> 6
        features = F.relu(self.conv3(features))             # 6 -> 4 (no pool)
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(hidden)
|
|
|
|
|
# Build the network, restore the trained weights from cnn.pth onto the CPU,
# and switch the module to inference mode (train(False) is what eval() does).
# NOTE(review): torch.load unpickles the checkpoint; on PyTorch versions where
# weights_only defaults to False this can run arbitrary code from the file,
# so only ship/load checkpoints you trust.
model = ConvNet()
model.load_state_dict(torch.load("cnn.pth", map_location="cpu"))
model.train(False)
|
|
|
|
|
# Human-readable label for each model output; order matters, since entry i
# names the model's i-th output score.
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
# force a 32x32 size (aspect ratio is not preserved), convert to a float
# tensor in [0, 1], then shift each of the 3 channels to roughly [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
|
|
|
def predict(img):
    """Classify one image and return per-class confidences for ``gr.Label``.

    Args:
        img: A PIL image, or an array accepted by ``PIL.Image.fromarray``
            (e.g. a NumPy image). ``None`` when the input is empty.

    Returns:
        ``None`` if ``img`` is ``None``; otherwise a dict mapping each class
        name to its softmax probability in [0, 1], ordered from most to
        least likely.
    """
    if img is None:
        return None

    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # The network's first conv layer takes exactly 3 channels; grayscale,
    # palette or RGBA images would otherwise break Normalize / conv1.
    img = img.convert("RGB")

    batch = transform(img).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)[0]

    # gr.Label expects confidences in [0, 1] and renders them as percentages
    # itself, so the raw probabilities are returned (scaling by 100 here
    # would display values like "9500%").
    predictions = {
        label: prob for label, prob in zip(classes, probabilities.tolist())
    }

    return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
|
|
|
# Web UI: a PIL image input feeds predict(), and a Label component shows the
# confidence for all 10 classes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=10),
    examples=[["example1.jpeg"], ["example2.jpeg"]],
    title="CIFAR-10 Image Classifier",
    description="Upload an image to classify it into one of these categories: plane, car, bird, cat, deer, dog, frog, horse, ship, or truck. Results show prediction confidence for all classes as percentages.",
)

# Only start the server when executed as a script (e.g. `python app.py`), so
# importing this module does not block on launch().
if __name__ == "__main__":
    iface.launch()