Last commit not found
raw
history blame
1.3 kB
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import requests
# Class names in model-output order: predict() pairs labels[i] with logit i.
labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]
# Single Hugging Face checkpoint used for both the image preprocessor and the
# classification model, so the two always come from the same source.
MODEL_ID = '0-ma/vit-geometric-shapes-tiny'
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
def _logits_to_confidences(logits, class_names):
    """Convert one row of raw logits into a ``{label: confidence}`` mapping.

    Negative logits are clipped to zero and the remaining scores are divided
    by their sum, so the returned confidences sum to 1. Only classes whose
    logit is strictly positive appear in the result.

    Args:
        logits: 1-D array-like of raw scores, positionally aligned with
            ``class_names``.
        class_names: Sequence of label strings; ``class_names[i]`` names
            ``logits[i]``.

    Returns:
        dict mapping label -> Python ``float``. Empty when no logit is
        positive (rather than dividing 0 by 0 and producing NaNs).
    """
    scores = np.asarray(logits)
    # np.clip returns a new array, so the caller's logits are never mutated.
    positive = np.clip(scores, 0, None)
    total = positive.sum()
    if total <= 0:
        # Nothing scored above zero: avoid 0/0 -> NaN + RuntimeWarning.
        return {}
    normalized = positive / total
    return {
        name: float(conf)
        for name, raw, conf in zip(class_names, scores, normalized)
        if raw > 0
    }


def predict(image):
    """Classify a PIL image and return per-label confidences for ``gr.Label``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image was provided.

    Returns:
        dict mapping shape label -> confidence in [0, 1]; only labels with a
        positive logit are included. Empty for a missing image or when no
        logit is positive.
    """
    if image is None:
        return {}

    import torch  # local import: only needed to disable autograd here

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Pure inference: no_grad avoids building an autograd graph that the
    # original discarded afterwards with .detach().
    with torch.no_grad():
        logits = model(**inputs)["logits"]
    # A single image was passed, so take row 0 of the batch.
    return _logits_to_confidences(logits.cpu().numpy()[0], labels)
title = "Geometric Shape Classifier"
description = "The geometric shape classifier: 0-ma/vit-geometric-shapes-tiny."
# Sample input images offered in the UI (file names suggest one per label).
examples = [
    'example/1_None.jpg',
    'example/2_Circle.jpg',
    'example/3_Triangle.jpg',
    'example/4_Square.jpg',
    'example/5_Pentagone.jpg',
    'example/6_Hexagone.jpg',
]

# Build the UI: a PIL image goes in, predict()'s {label: confidence} dict comes out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()