Spaces:
Running
Running
Last commit not found
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Human-readable class names. predict() pairs entry i of this list with
# entry i of the model's logits vector.
# NOTE(review): assumes this order matches the checkpoint's own class
# indices (model.config.id2label) — confirm.
labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]
# Hub checkpoint shared by the image processor and the classifier, so the
# two can never drift apart.
_MODEL_ID = '0-ma/vit-geometric-shapes-tiny'

# Image processor used by predict() to turn input images into PyTorch tensors.
feature_extractor = AutoImageProcessor.from_pretrained(_MODEL_ID)
# Image-classification model whose "logits" output predict() reads.
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
def predict(image):
    """Classify the geometric shape in *image*.

    Runs the classifier and converts its raw logits into a ``label -> score``
    mapping suitable for ``gr.Label``. Scores are the positive logits rescaled
    so they sum to 1. Classes whose logit is not positive are omitted.

    Args:
        image: The input image, as delivered by ``gr.Image(type="pil")``.
            ``None`` when the user submits without providing an image.

    Returns:
        dict[str, float]: Label name mapped to its normalised score. The dict
        is empty when no image was given or when no class has a positive
        logit.
    """
    if image is None:
        # Nothing to classify; avoid passing None into the image processor.
        return {}

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**inputs)["logits"].cpu().numpy()[0]

    # np.maximum returns a new array, so `logits` itself is left untouched.
    positive = np.maximum(logits, 0)
    total = positive.sum()
    if total <= 0:
        # Every logit is <= 0: no class gets a score, and normalising would
        # divide by zero.
        return {}

    # NOTE(review): assumes `labels` is in the model's class-index order.
    return {
        label: float(score / total)
        for label, score in zip(labels, positive)
        if score > 0
    }
title = "Geometric Shape Classifier"
description = "The geometric shape classifier: 0-ma/vit-geometric-shapes-tiny."
# Example inputs offered in the UI, one image per class in `labels` order.
examples = [
    'example/1_None.jpg',
    'example/2_Circle.jpg',
    'example/3_Triangle.jpg',
    'example/4_Square.jpg',
    'example/5_Pentagone.jpg',
    'example/6_Hexagone.jpg',
]

# PIL image in, label/score mapping out (rendered by gr.Label).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()