Spaces:
Sleeping
Sleeping
File size: 1,063 Bytes
da43dfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
import gradio as gr
from torchvision import transforms
from torchvision import models
from PIL import Image
from src.config import Config
from src.utils import id2label
def predict(image):
    """Classify an uploaded image and return its most likely labels.

    Args:
        image: The value supplied by the Gradio input component. Assumed to be
            a NumPy-style array (anything with ``.astype``) of shape (H, W),
            (H, W, 3) or (H, W, 4) -- TODO confirm against the Gradio
            ``Image`` component's ``type``/``image_mode`` settings.

    Returns:
        dict mapping capitalized class names (from ``id2label``) to softmax
        probabilities as Python floats, highest first, with at most five
        entries (fewer if the model has fewer than five classes).

    Raises:
        gr.Error: if no image was provided or it cannot be converted to RGB.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    # 1.143 is roughly 256/224: resize the short side a bit larger than the
    # crop, then center-crop -- presumably mirroring the eval-time transform
    # used in training (TODO confirm).
    transform = transforms.Compose([
        transforms.Resize(int(Config.imgsize * 1.143)),
        transforms.CenterCrop(Config.imgsize),
        transforms.ToTensor(),
        # ImageNet per-channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    try:
        # Let PIL infer the mode from the array shape, then convert. Passing
        # mode='RGB' to fromarray only reinterprets the raw bytes, so
        # grayscale or RGBA arrays would fail or be misread.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error("Image could not be converted to RGB. Please try another image.") from err

    model = Config.model
    # Inference mode for dropout/batch-norm; a no-op if already in eval mode.
    model.eval()
    batch = transform(pil_image).unsqueeze(0)
    # No gradients are needed for prediction; avoid building an autograd graph.
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probs.shape[1])
    top = torch.topk(probs, k)
    return {
        id2label(idx).capitalize(): prob
        for idx, prob in zip(top.indices[0].tolist(), top.values[0].tolist())
    }
|