Spaces:
Sleeping
Sleeping
import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download, snapshot_download
from torch import nn
class LeNet(nn.Module):
    """Small LeNet-style convolutional classifier producing 10 logits.

    The attribute names ``convs`` / ``linear`` and the layer order inside each
    ``nn.Sequential`` define the ``state_dict`` keys (``convs.0.weight``,
    ``convs.3.weight``, ``linear.0.weight``, ...), so they must stay in sync
    with any saved checkpoint that is loaded into this model.
    """

    def __init__(self):
        super().__init__()
        # Two conv(5x5) -> tanh -> 2x2 average-pool stages: 1 -> 4 -> 12 channels.
        feature_layers = [
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        ]
        self.convs = nn.Sequential(*feature_layers)
        # For a 28x28 input: 28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4,
        # leaving 12 channels of 4x4 features for the classifier head.
        self.linear = nn.Sequential(nn.Linear(4 * 4 * 12, 10))

    def forward(self, x):
        """Return raw class logits, one row of 10 values per batch element."""
        features = torch.flatten(self.convs(x), 1)
        return self.linear(features)

    def predict(self, input):
        """Classify a single image and return its 10 class probabilities.

        ``input`` must hold exactly 28*28 elements; it is reshaped into a
        batch of one single-channel 28x28 image. Gradient tracking is left to
        the caller (e.g. wrap the call in ``torch.no_grad()``).
        """
        single_batch = input.reshape(1, 1, 28, 28)
        logits = self(single_batch)[0]
        return nn.functional.softmax(logits, dim=0)
# Build the network and load its pretrained weights from the Hugging Face Hub.
lenet = LeNet()
# Download only the checkpoint file instead of snapshotting the whole repo;
# the returned value is the local path of that file.
lenet_pt = hf_hub_download('stanimirovb/ibob-lenet-v1', 'lenet-v1.pth')
# weights_only=True restricts the unpickler to tensors, primitive types and
# containers, so the remote file cannot smuggle in arbitrary pickled objects.
lenet.load_state_dict(
    torch.load(lenet_pt, map_location='cpu', weights_only=True)
)
# Inference-only model. There are no dropout/batch-norm layers today, so this
# is a no-op now but keeps predictions correct if such layers are added.
lenet.eval()
# Scales an incoming (C, H, W) drawing down to the 28x28 size LeNet expects.
resize = torchvision.transforms.Resize((28, 28), antialias=True)
def on_submit(img):
    """Classify a Sketchpad drawing and list the class probabilities.

    Args:
        img: Value passed by ``gr.Sketchpad``: a dict whose ``'composite'``
            entry is the flattened drawing as a numpy array. With
            ``image_mode='P'`` this is presumably a 2-D (H, W) array —
            NOTE(review): confirm against the installed gradio version.

    Returns:
        One line per class, ``"<index>: <probability>"``, ordered from most
        to least likely; an empty string when there is no drawing to classify.
    """
    # Guard against a submission with no image data instead of crashing on
    # None.astype(...) deep inside the handler.
    if img is None or img.get('composite') is None:
        return ""

    with torch.no_grad():
        pixels = torch.from_numpy(img['composite'].astype(np.float32))
        # Add a leading channel axis so Resize sees (C, H, W), then shrink
        # to the 28x28 input the network expects.
        pixels = resize(pixels.unsqueeze(0))
        # .tolist() yields plain Python floats; formatting numpy scalars via
        # str(list) prints "np.float32(...)" under NumPy >= 2.
        probabilities = lenet.predict(pixels).tolist()

    # Highest probability first; sorted() is stable, so ties keep ascending
    # class-index order exactly as the original sort did.
    ranking = sorted(enumerate(probabilities), key=lambda pair: -pair[1])
    return "\n".join(f"{digit}: {prob:.4f}" for digit, prob in ranking)
# Web UI: the user draws on a sketchpad and sees the ranked class
# probabilities produced by on_submit as plain text.
sketch_input = gr.Sketchpad(image_mode='P')
ranking_output = gr.Text()
iface = gr.Interface(
    fn=on_submit,
    inputs=sketch_input,
    outputs=ranking_output,
    title="LeNet",
)
# Starts the Gradio server when this script runs.
iface.launch()