Spaces:
Runtime error
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torchvision.transforms import ToTensor | |
# Model definition
class ConvNet(nn.Module):
    """Three-convolution CNN producing 10 class logits from 1-channel images.

    The attribute names conv1..conv3, fc1 and fc2 are the keys of the saved
    state_dict, so they must stay exactly as they are.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 32 -> 64 channels, 5x5 kernels, no padding.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        # Classifier head: flattened 64x3x3 feature map -> 256 hidden -> 10 logits.
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map an image batch to 10 unnormalised class scores (logits)."""

        def pooled_block(conv, t):
            # conv -> 2x2 max-pool -> ReLU -> dropout (dropout only while training)
            t = F.relu(F.max_pool2d(conv(t), 2))
            return F.dropout(t, p=0.5, training=self.training)

        # Spatial sizes for a 28x28 input: 28 -conv1-> 24 -conv2-> 20 -pool-> 10
        # -conv3-> 6 -pool-> 3, hence the 3*3*64 flatten size below.
        features = F.relu(self.conv1(x))
        features = pooled_block(self.conv2, features)
        features = pooled_block(self.conv3, features)
        hidden = F.relu(self.fc1(features.view(-1, 3 * 3 * 64)))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.fc2(hidden)
# Instantiate the network and restore its trained parameters, remapping every
# stored tensor onto the CPU regardless of the device it was saved from.
# NOTE(review): the weights path is relative to the current working directory.
model = ConvNet()
model.load_state_dict(
    torch.load(
        "weights/mnist_convnet_model.pth",
        map_location=torch.device('cpu'),
    )
)
# train(False) is what eval() does: it switches off the dropout in forward().
model.train(False)
import gradio as gr | |
from torchvision import transforms | |
def predict(image):
    """Classify a handwritten-digit image with the loaded ConvNet.

    Args:
        image: PIL image from a Gradio ``Image`` component (``type="pil"``),
            or ``None`` when the component is empty.

    Returns:
        Dict mapping each digit 0-9 to its softmax probability (the format
        ``gr.Label`` displays), or ``None`` when no image was supplied.
    """
    if image is None:
        # Clicking the button with nothing drawn/uploaded passes None; clear
        # the label instead of letting ToTensor raise.
        return None
    # The flatten size in ConvNet (3*3*64) only works for a 1x28x28 input, so
    # normalise mode and size here instead of relying on the UI to do it.
    # For an image that is already 28x28 grayscale this changes nothing.
    image = image.convert("L").resize((28, 28))
    # ToTensor yields (1, 28, 28); add an explicit batch dimension -> (1, 1, 28, 28).
    tsr_image = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        pred = model(tsr_image)
        prob = torch.nn.functional.softmax(pred[0], dim=0)
    return {i: float(prob[i]) for i in range(10)}
# Gradio UI: two input tabs (draw on a canvas / upload a file) feed one shared
# probability label; both "推論する" ("run inference") buttons call predict().
with gr.Blocks(css=".gradio-container {background:honeydew;}", title="MNIST 分類器") as demo:
    # Page heading: "MNIST 分類器" = "MNIST classifier".
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">MNIST 分類器</div>""")
    with gr.Row():
        # "キャンバス" = canvas tab: sketch a digit directly in the browser.
        # NOTE(review): source=, shape= and invert_colors= are Gradio 3.x
        # arguments that Gradio 4.0 removed; if building the UI fails with an
        # unexpected-keyword TypeError, pin gradio<4 or port these inputs
        # (e.g. to gr.Sketchpad) — confirm against the Space's requirements.
        with gr.Tab("キャンバス"):
            input_image1 = gr.Image(label="スケッチ", source="canvas", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn1 = gr.Button("推論する")
        # "画像ファイル" = image-file tab: upload a digit image or pick an example.
        with gr.Tab("画像ファイル"):
            input_image2 = gr.Image(label="画像入力", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn2 = gr.Button("推論する")
            gr.Examples(['examples/sample02.png', 'examples/sample04.png'], inputs=input_image2)
        # "推論確率" = inference probabilities; shows the three most likely digits.
        output_label = gr.Label(label="推論確率", num_top_classes=3)
    # Event listeners must be registered inside the Blocks context.
    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)

# Request queueing is intentionally not enabled here. NOTE(review): the
# previously commented-out queue(concurrency_count=3) is not accepted by newer
# Gradio releases — check the installed version's queue() signature first.
demo.launch()
### EOF ### |