# Hugging Face Space by stanimirovb
# Commit 1364f9f (verified): "load with snapshot_download"
import os

import gradio as gr
import numpy as np
import torch
import torchvision
from huggingface_hub import snapshot_download
from torch import nn
class LeNet(nn.Module):
    """Small LeNet-style CNN: two conv/tanh/avg-pool stages and one linear head.

    The submodule attribute names and the layer order inside each Sequential
    determine the ``state_dict`` keys (``convs.0``, ``convs.3``, ``linear.0``)
    that the saved checkpoint is loaded into, so they must not change.
    """

    def __init__(self):
        super().__init__()
        # For a 1x28x28 input: conv5 -> 4x24x24, pool -> 4x12x12,
        # conv5 -> 12x8x8, pool -> 12x4x4.
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        # 12 channels * 4 * 4 spatial positions = 192 features -> 10 class scores.
        self.linear = nn.Sequential(nn.Linear(12 * 4 * 4, 10))

    def forward(self, x):
        """Map a batch of 1-channel images to unnormalized class scores."""
        features = self.convs(x)
        return self.linear(features.flatten(start_dim=1))

    @torch.no_grad()
    def predict(self, input):
        """Return softmax probabilities over the 10 classes for a single image.

        ``input`` must hold exactly 28*28 elements; it is reshaped into a
        one-image batch of shape (1, 1, 28, 28). Runs without autograd.
        """
        scores = self(input.reshape(1, 1, 28, 28))[0]
        return torch.softmax(scores, dim=0)
# Build the network and load the published weights from the Hugging Face Hub.
lenet = LeNet()
# Only the checkpoint file is used, so restrict the repo download to it.
lenet_pt = os.path.join(
    snapshot_download('stanimirovb/ibob-lenet-v1', allow_patterns='lenet-v1.pth'),
    'lenet-v1.pth',
)
# weights_only=True: the checkpoint comes from a remote repo, and a plain
# torch.load would unpickle arbitrary objects (code-execution risk).
lenet.load_state_dict(torch.load(lenet_pt, map_location='cpu', weights_only=True))
# Inference only. eval() changes nothing for Conv2d/Tanh/AvgPool2d/Linear, but
# keeps the app correct if dropout/batch-norm layers are ever added.
lenet.eval()
# Scales the user's drawing down to the 28x28 size that LeNet.predict reshapes to.
resize = torchvision.transforms.Resize((28, 28), antialias=True)
def on_submit(img):
    """Classify the digit drawn on the sketchpad and rank all classes.

    Args:
        img: the value Gradio passes for the ``gr.Sketchpad`` input. The code
            reads its ``'composite'`` entry as a numpy array (presumably 2-D
            for ``image_mode='P'`` -- TODO confirm against the Gradio version
            in use). May be ``None`` when nothing was submitted.

    Returns:
        One line per class, ``"<digit>: <probability>"``, ordered from most
        to least likely; an empty string when there is no drawing.
    """
    if img is None or img.get('composite') is None:
        # Nothing to classify; avoid crashing the handler.
        return ""
    with torch.no_grad():
        pixels = torch.from_numpy(img['composite'].astype(np.float32))
        # Add a leading channel axis so Resize scales the last two (H, W) dims.
        pixels = resize(pixels.unsqueeze(0))
        probs = lenet.predict(pixels)
    # .tolist() yields plain Python floats; str() of NumPy scalars renders as
    # "np.float32(...)" under NumPy 2, which leaked into the UI text.
    ranking = sorted(enumerate(probs.tolist()), key=lambda pair: pair[1], reverse=True)
    return "\n".join(f"{digit}: {prob:.4f}" for digit, prob in ranking)
# Gradio front end: the user draws on a sketchpad, on_submit ranks the digit
# classes, and the ranking is shown as plain text.
drawing_input = gr.Sketchpad(image_mode='P')
ranking_output = gr.Text()
iface = gr.Interface(
    fn=on_submit,
    inputs=drawing_input,
    outputs=ranking_output,
    title="LeNet",
)
# Serve the app.
iface.launch()