fulouma's picture
update description
62caad2
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
def predict(model, image, device):
    """Classify *image* as the Tenka Ippin logo or a no-entry sign.

    Args:
        model: Callable torch module producing one row of class logits per
            image. Class index 1 is reported as "tenka ippin"; the remaining
            probability mass is reported as "no entry".
        image: Image accepted by the torchvision transforms below (the Gradio
            input is configured with ``type="pil"``, so a PIL image in practice).
        device: Device to run inference on, e.g. ``"cpu"``. The model is moved
            to this device in place.

    Returns:
        dict mapping each label to its softmax probability as a plain Python
        float, the format ``gr.Label`` consumes.
    """
    # Resize / center-crop to 224x224 and normalize with the ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Normalize expects exactly 3 channels; uploaded PNGs may be RGBA, palette
    # or grayscale. Non-PIL inputs (no ``mode`` attribute) are passed through.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)

    # Honour the requested device and force inference mode: the mode the model
    # was pickled in is unknown, and train mode would apply dropout and
    # per-batch BatchNorm statistics.
    model.to(device)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)

    # Class index 1 -> "tenka ippin"; .item() yields a Python float on any device.
    p_tenka = probs[0, 1].item()
    return {"tenka ippin": p_tenka, "no entry": 1.0 - p_tenka}
@functools.lru_cache(maxsize=1)
def _load_model(path='models/tenichi_noentry.pth'):
    """Load the classifier from *path* once and reuse it across requests."""
    # The loaded object is called directly as a model in predict(), so the
    # file holds a pickled module rather than a state_dict. torch >= 2.6
    # defaults to weights_only=True, which refuses such files, so opt out
    # explicitly. Unpickling can run arbitrary code: only load trusted files.
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    return torch.load(path, map_location='cpu', weights_only=False)


def process_image(input_image):
    """Gradio callback: classify *input_image* and return label confidences.

    Args:
        input_image: Image delivered by the Gradio ``gr.Image(type="pil")``
            input.

    Returns:
        dict of label -> probability produced by ``predict``.
    """
    model = _load_model('models/tenichi_noentry.pth')
    return predict(model, input_image, 'cpu')
# Web UI: a single PIL image in, label confidences from process_image out.
iface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil", label="Input Image", height=512)],
    outputs=gr.Label(label="Output", show_label=False),
    # Japanese UI text: "Determines whether what appears in the image is the
    # Tenka Ippin logo or a no-entry sign."
    description="画像に映っているのが天下一品のロゴなのか、進入禁止標識なのか判別します",
    # One row per example; each row holds the value for the single image input.
    examples=[[path] for path in ("examples/ten20.png", "examples/noe33.png")],
)
def main():
    """Start the Gradio server for the interface defined in this module."""
    iface.launch()


if __name__ == "__main__":
    main()