Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
def predict(model, image, device):
    """Classify an image as the Tenkaippin logo or a no-entry sign.

    Args:
        model: a torch ``nn.Module`` whose output for one image has the class
            scores along dim 1. Index 1 is reported as "tenka ippin"; "no entry"
            is reported as the complement, which assumes a 2-class head
            (NOTE(review): confirm against the training label order).
        image: a PIL image (as supplied by ``gr.Image(type="pil")``). Non-RGB
            PIL images (RGBA, palette, greyscale) are converted to RGB first.
        device: torch device, or a device string such as ``"cpu"``, on which
            inference is run. The model is moved there in place.

    Returns:
        dict mapping ``"tenka ippin"`` and ``"no entry"`` to Python floats, the
        format expected by ``gr.Label``.
    """
    # Resize, centre-crop to 224x224 and normalise with the standard ImageNet
    # per-channel mean/std.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize has exactly 3 channel statistics; an RGBA/greyscale PIL image
    # would otherwise produce a 4- or 1-channel tensor and fail.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Add a batch dimension of 1 and place the input next to the model.
    img_tensor = transform(image).unsqueeze(0).to(device)

    # eval() turns off dropout and makes batch-norm use running statistics;
    # in train mode a single-image batch gives unreliable outputs.
    model.eval()
    model.to(device)
    with torch.no_grad():
        probs = torch.softmax(model(img_tensor), dim=1)

    # Row 0 = the single image in the batch; column 1 = "tenka ippin".
    # .item() yields a plain Python float (JSON-friendly for gr.Label).
    p_tenka = probs[0, 1].item()
    return {"tenka ippin": p_tenka, "no entry": 1.0 - p_tenka}
# Checkpoint resolved relative to this file, so the app also works when it is
# launched from a different working directory.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "models", "tenichi_noentry.pth"
)

# Loaded models keyed by checkpoint path: the file is read once per process
# instead of on every request.
_model_cache = {}


def _load_model(path=MODEL_PATH):
    """Return the classifier pickled at *path*, loaded once, on CPU, in eval mode."""
    model = _model_cache.get(path)
    if model is None:
        # SECURITY NOTE: the checkpoint is a whole pickled nn.Module, which
        # requires weights_only=False (the default became True in torch 2.6).
        # That unpickles arbitrary objects -- only load trusted checkpoints,
        # such as the one shipped with this repository.
        try:
            model = torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the `weights_only` argument.
            model = torch.load(path, map_location="cpu")
        model.eval()
        _model_cache[path] = model
    return model


def process_image(input_image):
    """Gradio callback: classify *input_image* and return label probabilities.

    Args:
        input_image: PIL image from the ``gr.Image`` input, or ``None`` when
            the user submits without an image.

    Returns:
        dict of label -> probability for ``gr.Label``, or ``None`` (which
        clears the output) when no image was provided.
    """
    if input_image is None:
        return None
    # map_location="cpu" above lets a GPU-saved checkpoint load on CPU hardware.
    model = _load_model()
    return predict(model, input_image, "cpu")
# Sample inputs shown under the UI; each inner list holds one value per input
# component (here: a single image path).
_example_images = [
    ["examples/ten20.png"],
    ["examples/noe33.png"],
]

# The web UI: one PIL image in, a label -> probability panel out.
# The description (Japanese) reads: "Determines whether the image shows the
# Tenkaippin logo or a no-entry road sign."
iface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil", label="Input Image", height=512)],
    outputs=gr.Label(label="Output", show_label=False),
    description="画像に映っているのが天下一品のロゴなのか、進入禁止標識なのか判別します",
    examples=_example_images,
)
def main():
    """Start the Gradio server for the ``iface`` app."""
    iface.launch()


if __name__ == "__main__":
    main()