Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from torchvision import transforms | |
from torchvision import models | |
from PIL import Image | |
from src.config import Config | |
from src.utils import id2label | |
# Per-channel normalization statistics. These are the standard ImageNet
# mean/std values; presumably the backbone (torchvision ``models``) was
# ImageNet-pretrained — confirm against the training pipeline.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
# Resize to ~1.143x the crop size before center-cropping (1.143 ≈ 256 / 224,
# the common resize-then-crop ratio), so the crop keeps the central region.
_RESIZE_RATIO = 1.143
# Maximum number of classes reported back to the UI.
_TOP_K = 5


def _build_transform(img_size):
    """Return the eval-time preprocessing pipeline for a crop of ``img_size``."""
    return transforms.Compose([
        transforms.Resize(int(img_size * _RESIZE_RATIO)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])


def predict(image):
    """Classify ``image`` and return the top predictions with their probabilities.

    Args:
        image: Array-like image exposing ``.astype`` — presumably the NumPy
            array a Gradio image component passes in (TODO confirm). ``None``
            means nothing was uploaded. Grayscale and RGBA inputs are
            converted to RGB.

    Returns:
        dict mapping each capitalized class label (via ``id2label``) to its
        softmax probability, for the top ``min(5, num_classes)`` classes.

    Raises:
        gr.Error: if no image was provided or it cannot be turned into an
            RGB PIL image.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    try:
        # fromarray infers the mode from the array; convert() then really
        # produces RGB (passing 'RGB' to fromarray only reinterprets bytes,
        # which fails for grayscale and garbles RGBA input).
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    except (AttributeError, TypeError, ValueError) as err:
        raise gr.Error(
            "Image could not be converted to RGB. Please try another image."
        ) from err

    batch = _build_transform(Config.imgsize)(pil_image).unsqueeze(0)

    # NOTE(review): assumes Config.model is already in eval() mode and on the
    # same device as ``batch`` (CPU) — confirm where Config builds the model.
    model = Config.model
    # Inference only: don't record an autograd graph for every request.
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(_TOP_K, probs.shape[1])
    top = torch.topk(probs, k)
    return {
        id2label(idx.item()).capitalize(): prob.item()
        for idx, prob in zip(top.indices[0], top.values[0])
    }