|
import torch |
|
import torchvision.transforms as transforms |
|
import gradio as gr |
|
from PIL import Image |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
def get_model_name(name, batch_size, learning_rate, epoch):
    """Generate a checkpoint name consisting of all the hyperparameter values.

    Args:
        name: Model name (e.g. the ``name`` attribute of a network instance).
        batch_size: Batch size hyperparameter encoded in the name.
        learning_rate: Learning-rate hyperparameter encoded in the name.
        epoch: Epoch number encoded in the name.

    Returns:
        A string of the form
        ``model_{name}_bs{batch_size}_lr{learning_rate}_epoch{epoch}``, with each
        value rendered using its default string formatting.
    """
    return f"model_{name}_bs{batch_size}_lr{learning_rate}_epoch{epoch}"
|
|
|
class LargeNet(nn.Module):
    """Two-convolution CNN classifier that outputs 8 logits per image.

    Layer sizes are fixed for 3-channel 128x128 inputs, which is the size the
    module-level ``transform`` resizes to:
    128 -conv5-> 124 -pool-> 62 -conv5-> 58 -pool-> 29, so the flattened
    feature vector has 10 * 29 * 29 elements.

    The submodule attribute names (conv1, conv2, fc1, fc2) are the state_dict
    keys stored in saved checkpoints and must not be renamed.

    NOTE(review): fc2 produces 8 logits while ``class_names`` in this file lists
    only 7 labels — confirm the intended label set against the training data.
    """

    def __init__(self):
        super().__init__()
        # load_model() passes this to get_model_name() to locate the checkpoint.
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return logits of shape (N, 8) for a batch ``x`` of shape (N, 3, 128, 128)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flattening from dim 1 keeps the batch size fixed, so a wrongly sized
        # input fails loudly in fc1 instead of being reshaped into extra "images".
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
|
|
|
# Preprocessing applied to each uploaded image before it reaches the network:
# resize to 128x128 (the spatial size LargeNet's fc1 layer is sized for),
# convert to a float tensor, then normalize with a single-element mean/std
# that broadcasts across all channels.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
|
|
|
def load_model():
    """Build a LargeNet, load its saved weights, and return it in eval mode.

    The checkpoint path comes from get_model_name() using the network's name and
    the hyperparameter values encoded below (batch size 128, lr 0.001, epoch 29).
    The path is relative, so it is resolved against the current working directory.

    Returns:
        A LargeNet instance with loaded weights, in evaluation mode.
    """
    net = LargeNet()
    model_path = get_model_name(net.name, batch_size=128, learning_rate=0.001, epoch=29)
    # Inference runs on CPU tensors; without map_location a checkpoint saved
    # from GPU tensors cannot be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    state = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state)
    # LargeNet has no dropout/batchnorm, so eval() is a safeguard rather than a
    # behavioral change today.
    net.eval()
    return net
|
|
|
# Display labels, indexed by the position of LargeNet's highest-scoring logit.
# NOTE(review): LargeNet.fc2 emits 8 logits but only 7 labels are listed, so
# logit index 7 has no name — confirm the label set/order against the training
# dataset.
class_names = ["Gasoline_Can", "Pebbles", "pliers", "Screw_Driver", "Toolbox", "Wrench", "other"]
|
|
|
|
|
# Lazily loaded classifier shared across requests (see _get_model).
_cached_model = None


def _get_model():
    """Return the classifier, loading it from disk only on first use."""
    global _cached_model
    # Concurrent first requests may each load the model; the last assignment
    # wins, which is harmless because every load produces identical weights.
    if _cached_model is None:
        _cached_model = load_model()
    return _cached_model


def predict(image):
    """Classify an uploaded image and return its class name.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or None when
            the user submits without uploading anything.

    Returns:
        The entry of ``class_names`` for the highest-scoring logit, or a
        placeholder naming the raw index if the model emits an index that has
        no label.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    model = _get_model()
    # conv1 expects 3 channels; RGBA, grayscale or palette uploads would
    # otherwise yield a 4- or 1-channel tensor and crash the forward pass.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    index = int(logits.argmax(dim=1).item())
    # The network has more outputs than there are labels; avoid an IndexError.
    if index >= len(class_names):
        return f"unknown (class index {index})"
    return class_names[index]
|
|
|
# Gradio front end: a single PIL image goes in, and the predicted class name
# comes out in a Label component.
interface = gr.Interface(
    title="Mechanical Tools Classifier",
    description="Upload an image to classify it as one of the mechanical tools.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
)
|
|
|
def main():
    """Start the Gradio web app for the classifier."""
    interface.launch()


if __name__ == "__main__":
    main()
|
|