File size: 1,399 Bytes
192d452
c7d5210
 
 
 
 
 
 
 
 
 
 
 
5dcea5c
 
 
7f73a40
 
c7d5210
192d452
e7f2b5e
f22056c
 
 
c7d5210
f22056c
c7d5210
f22056c
c7d5210
 
f22056c
192d452
 
 
99f5190
68dcd2b
192d452
e7f2b5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification 
import requests
# Human-readable class names; predict() maps an output-logit index i to
# labels[i].
# NOTE(review): assumed to match the checkpoint's id2label ordering — confirm
# against model.config.id2label.
labels = [
    "None",
    "Circle",
    "Triangle",
    "Square",
    "Pentagon",
    "Hexagon",
]

# Hugging Face Hub checkpoint used for both preprocessing and classification.
MODEL_ID = '0-ma/vit-geometric-shapes-tiny'

# Loaded once at import time so every request can reuse the same instances.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# NOTE: the original rebound `labels = []` here, which wiped the class names
# above and made every prediction fail with IndexError; it has been removed.
def predict(image):
    """Classify the geometric shape in *image* for the Gradio UI.

    Uses the module-level ``feature_extractor`` and ``model`` (loaded once at
    import time) instead of reloading the checkpoint on every request.

    Args:
        image: PIL image supplied by ``gr.Image(type="pil")``; ``None`` when
            the user submits without uploading an image.

    Returns:
        dict mapping each class name in ``labels`` to its softmax probability,
        the class -> confidence format that ``gr.Label`` renders.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a processor traceback.
    """
    # Local import: torch is already required by return_tensors="pt" below.
    import torch

    if image is None:
        raise gr.Error("Please upload an image.")

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs)['logits'][0].cpu().detach().numpy()

    # Numerically stable softmax over the class scores of the single image.
    shifted = np.exp(logits - logits.max())
    probabilities = shifted / shifted.sum()

    best = int(np.argmax(probabilities))
    # Debug trace kept from the original: predicted class and its raw logit.
    print(labels[best], logits[best])

    return {label: float(prob) for label, prob in zip(labels, probabilities)}
 
title = "Geometric Shape Classifier"
description = "A geometric shape detector."
# Sample inputs shown beneath the image widget (relative file paths).
examples = [
    'example/1_None.jpg',
    'example/2_Circle.jpg',
    'example/3_Triangle.jpg',
    'example/4_Square.jpg',
    'example/5_Pentagone.jpg',
    'example/6_Hexagone.jpg',
]

# Build the UI: a PIL image goes in, predict() classifies it, and the result
# is rendered by a gr.Label.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()