import gradio as gr
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms

from src.config import Config
from src.utils import id2label


def _build_transform():
    """Return the resize / center-crop / normalize pipeline applied before inference."""
    return transforms.Compose([
        # Resize slightly larger than the crop size, then center-crop.
        # 1.143 is roughly 256/224 -- presumably the usual ImageNet
        # resize-to-crop ratio; confirm against the training pipeline.
        transforms.Resize(int(Config.imgsize * 1.143)),
        transforms.CenterCrop(Config.imgsize),
        transforms.ToTensor(),
        # NOTE(review): these look like the standard ImageNet channel
        # statistics -- confirm they match what the model was trained with.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


def predict(image):
    """Classify an uploaded image and return its top predictions.

    Args:
        image: The uploaded image as an array-like object exposing
            ``.astype`` (presumably the NumPy array Gradio passes in --
            confirm against the interface definition), or ``None`` when
            nothing was uploaded.

    Returns:
        A dict mapping capitalized label names to softmax probabilities for
        the highest-scoring classes: up to five entries, fewer when the model
        has fewer than five outputs.

    Raises:
        gr.Error: If no image was supplied, or the input cannot be turned
            into an RGB image.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    try:
        # Let Pillow infer the mode from the array shape, then normalize to
        # RGB so grayscale / RGBA uploads are handled instead of misread.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    except Exception as err:
        # UI boundary: surface any conversion failure as a user-facing
        # message, but keep the original cause chained for debugging and do
        # not trap KeyboardInterrupt / SystemExit.
        raise gr.Error(
            "Image could not be converted to RGB. Please try another image."
        ) from err

    batch = _build_transform()(pil_image).unsqueeze(0)  # add a batch dimension

    # NOTE(review): assumes Config.model is already in eval mode and on the
    # same device as `batch` -- confirm where the model is loaded.
    model = Config.model

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1)

    # Never ask for more classes than the model actually outputs.
    k = min(5, probabilities.shape[1])
    top = torch.topk(probabilities, k)

    return {
        id2label(class_id).capitalize(): score
        for class_id, score in zip(top.indices[0].tolist(), top.values[0].tolist())
    }