ilyi
Refine Gradio App.
da43dfa
import torch
import gradio as gr
from torchvision import transforms
from torchvision import models
from PIL import Image
from src.config import Config
from src.utils import id2label
# Maximum number of labels returned by predict(); clamped to the model's class count.
_TOP_K = 5


def _to_pil_rgb(image):
    """Convert an array-like image into a 3-channel RGB PIL image.

    ``Image.fromarray`` infers the mode from the array shape (e.g. "L" for
    2-D, "RGBA" for 4 channels); ``convert("RGB")`` then normalises it to
    three channels instead of forcing mode "RGB" onto a buffer whose channel
    count may not match.

    Raises:
        gr.Error: if the input cannot be turned into an RGB image.
    """
    try:
        return Image.fromarray(image.astype('uint8')).convert("RGB")
    except (AttributeError, TypeError, ValueError) as err:
        # AttributeError: input has no .astype (not array-like);
        # TypeError/ValueError: unsupported shape/dtype or mode conversion.
        raise gr.Error("Image could not be converted to RGB. Please try another image.") from err


def _preprocess(pil_image):
    """Resize, center-crop, tensorise and normalise an image into a batch of one.

    Returns a tensor of shape (1, C, Config.imgsize, Config.imgsize).
    """
    transform = transforms.Compose([
        # Resize the shorter side to ~1.143x the crop size (256 when imgsize is
        # 224) so that CenterCrop trims a border before classification.
        transforms.Resize(int(Config.imgsize * 1.143)),
        transforms.CenterCrop(Config.imgsize),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std used by torchvision models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ToTensor yields (C, H, W); unsqueeze adds the batch dimension.
    return transform(pil_image).unsqueeze(0)


def predict(image):
    """Classify an uploaded image and return its most probable labels.

    Args:
        image: The image from the Gradio input component. It is expected to be
            a NumPy-style array (it must support ``.astype``); grayscale and
            alpha-channel images are converted to RGB.

    Returns:
        dict[str, float]: Up to ``_TOP_K`` entries mapping the capitalised
        label (via ``id2label``) to its softmax probability, ordered from most
        to least likely.

    Raises:
        gr.Error: if no image was provided or it cannot be converted to RGB.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    batch = _preprocess(_to_pil_rgb(image))

    model = Config.model
    # NOTE(review): assumes Config.model is already in eval() mode — confirm;
    # otherwise dropout / batch-norm would behave as during training.
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)

    # Never ask topk for more entries than the model has classes.
    k = min(_TOP_K, probs.shape[1])
    top = torch.topk(probs, k)
    return {
        id2label(idx).capitalize(): prob
        for idx, prob in zip(top.indices[0].tolist(), top.values[0].tolist())
    }