import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import hf_hub_download, snapshot_download
from torch import nn


class LeNet(nn.Module):
    """Small LeNet-style CNN mapping a 1x28x28 image to 10 class scores.

    Two ``Conv2d(5x5) -> Tanh -> AvgPool2d(2, 2)`` stages (1 -> 4 -> 12
    channels) are followed by a single ``Linear`` layer over the flattened
    12 x 4 x 4 feature map (28 -> 24 -> 12 -> 8 -> 4 per spatial side).

    The attribute names ``convs`` and ``linear`` and the position of each
    layer inside its ``nn.Sequential`` define the ``state_dict`` keys, so
    they must stay as they are for saved checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()

        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )

        self.linear = nn.Sequential(
            nn.Linear(4 * 4 * 12, 10),
        )

    def forward(self, x):
        """Return raw class scores of shape (batch, 10) for a (batch, 1, 28, 28) input."""
        features = self.convs(x)
        # Keep the batch axis, flatten channels and spatial dims for the linear head.
        features = torch.flatten(features, 1)

        return self.linear(features)

    @torch.no_grad()
    def predict(self, input):
        """Return the class probabilities for one 28x28 image.

        Args:
            input: Tensor, or anything ``torch.as_tensor`` accepts (NumPy
                array, nested list), holding exactly 28 * 28 values. It is
                converted to the dtype and device of the model weights and
                reshaped to ``(1, 1, 28, 28)``. The parameter name shadows
                the builtin but is kept for keyword-argument callers.

        Returns:
            A 1-D tensor of 10 probabilities (softmax over the class scores).
        """
        # Match the weights so integer / float64 / off-device inputs do not
        # fail inside the first convolution; this is a no-op (no copy) when
        # the input already matches.
        reference = self.linear[0].weight
        image = torch.as_tensor(
            input, dtype=reference.dtype, device=reference.device
        )

        scores = self(image.reshape(1, 1, 28, 28))
        # scores[0] drops the singleton batch axis, so softmax runs over classes.
        return nn.functional.softmax(scores[0], dim=0)


# Build the network and restore the published LeNet-v1 weights on the CPU.
lenet = LeNet()

# Fetch only the checkpoint file instead of a snapshot of the whole repo;
# hf_hub_download returns the local (cached) path of that single file.
lenet_pt = hf_hub_download('stanimirovb/ibob-lenet-v1', 'lenet-v1.pth')

# The checkpoint comes from a remote repository and torch.load unpickles it:
# weights_only=True restricts loading to tensors and plain containers, which
# is all a state_dict needs.
lenet.load_state_dict(
    torch.load(lenet_pt, map_location='cpu', weights_only=True)
)
# The model is only used for inference from here on.
lenet.eval()

# Sketches are rescaled to the 28x28 input size that LeNet.predict reshapes to.
resize = torchvision.transforms.Resize((28, 28), antialias=True)


def on_submit(img):
    """Classify the sketch submitted from the Gradio sketchpad.

    Args:
        img: Editor payload; only its ``'composite'`` entry is read, and it
            is treated as a NumPy array (presumably a 2-D single-channel
            image -- confirm against the Sketchpad configuration).

    Returns:
        One ``[class_index, probability]`` line per class, most probable
        first, or an empty string when there is no image to classify.
    """
    composite = None if img is None else img['composite']
    if composite is None:
        # Nothing was submitted; avoid crashing on ``None.astype``.
        return ""

    pixels = torch.from_numpy(composite.astype(np.float32))
    # Add a leading axis so Resize sees a (channels, height, width) tensor.
    pixels = resize(pixels.unsqueeze(0))

    # No torch.no_grad() needed here: LeNet.predict is already decorated with it.
    probabilities = lenet.predict(pixels)

    ranking = sorted(
        enumerate(probabilities.numpy()),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Format each scalar with str(): the repr of NumPy scalars became
    # "np.float32(...)" in NumPy 2, which str() of a list would expose.
    return "\n".join(f"[{index}, {prob!s}]" for index, prob in ranking)


# Assemble the demo UI: a sketchpad (created with image_mode='P') feeds
# on_submit, whose text result is shown in a text box. The components are
# created in the same order as before: input first, then output.
sketch_input = gr.Sketchpad(image_mode='P')
text_output = gr.Text()

iface = gr.Interface(
    fn=on_submit,
    inputs=sketch_input,
    outputs=text_output,
    title="LeNet",
)

# Start the Gradio server for the interface.
iface.launch()
|