File size: 1,063 Bytes
da43dfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import gradio as gr
from torchvision import transforms
from torchvision import models
from PIL import Image

from src.config import Config
from src.utils import id2label


def predict(image):
    """Classify an uploaded image and return its top predictions.

    Args:
        image: The image delivered by the Gradio input component. The code
            calls ``image.astype``, so it is presumably a NumPy array of shape
            (H, W) or (H, W, C) — confirm against the component's ``type``.

    Returns:
        dict[str, float]: Up to five entries, mapping each capitalized label
        name (from ``id2label``) to its softmax probability. Entries are
        ordered from most to least likely.

    Raises:
        gr.Error: If no image was supplied, or if it cannot be turned into an
            RGB PIL image.
    """
    if image is None:
        raise gr.Error("No image found. Please upload an image to predict.")

    # Resize the shorter side to ~1.143x the crop size (e.g. 256 for a 224
    # crop), then center-crop. This presumably mirrors the preprocessing used
    # when the model was trained/evaluated — TODO confirm.
    transform = transforms.Compose([
        transforms.Resize(int(Config.imgsize * 1.143)),
        transforms.CenterCrop(Config.imgsize),
        transforms.ToTensor(),
        # ImageNet per-channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # NOTE(review): assumes Config.model is already in eval() mode so that
    # dropout/batch-norm behave deterministically — confirm in src.config.
    model = Config.model

    try:
        # Let Pillow infer the mode from the array shape (L, RGB, RGBA, ...) and
        # then convert. Forcing 'RGB' would misread grayscale or RGBA arrays.
        pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    except Exception as err:
        raise gr.Error("Image could not be converted to RGB. Please try another image.") from err

    # ToTensor yields (C, H, W); add a leading batch dimension for the model.
    batch = transform(pil_image).unsqueeze(0)

    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probs.shape[1])
    top = torch.topk(probs, k)

    return {
        id2label(idx).capitalize(): score
        for idx, score in zip(top.indices[0].tolist(), top.values[0].tolist())
    }