File size: 1,882 Bytes
192d452
c7d5210
 
 
 
 
 
 
 
 
 
 
 
 
 
7f73a40
 
c7d5210
 
 
 
 
192d452
 
 
f22056c
c7d5210
f22056c
 
 
c7d5210
f22056c
c7d5210
f22056c
c7d5210
 
f22056c
192d452
 
 
99f5190
192d452
 
 
99f5190
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification 
import requests
# Class names indexed by the model's output position (predict() maps the
# argmax / softmax index into this list).
# NOTE(review): assumed to match the checkpoint's id2label ordering — confirm
# against model.config.id2label.
labels = [
    "None",
    "Circle",
    "Triangle",
    "Square",
    "Pentagon",
    "Hexagon",
]

# Hugging Face hub id of the image-classification checkpoint used by the app.
MODEL_ID = '0-ma/vit-geometric-shapes-tiny'

# Load the preprocessing pipeline and classifier once, at import time.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)



def predict(img):
    """Classify the geometric shape shown in a single image.

    Uses the module-level ``feature_extractor``, ``model`` and ``labels``
    (the class-name list defined above); nothing is reloaded per call.

    Args:
        img: Image from the Gradio ``Image`` input. Presumably a uint8
            ``numpy.ndarray`` of shape (H, W, C) — TODO confirm for the
            pinned Gradio version. A ``PIL.Image.Image`` is also accepted.

    Returns:
        dict mapping each class name in ``labels`` to its softmax
        probability, which is the ``{label: confidence}`` format that
        ``gr.outputs.Label`` renders as ranked confidences.
    """
    # The previous code called fastai's ``PILImage.create``, which is not
    # imported in this file; convert with PIL directly instead.
    if isinstance(img, Image.Image):
        image = img
    else:
        image = Image.fromarray(np.asarray(img))
    # Normalise grayscale / RGBA uploads to the 3-channel input the
    # processor is fed elsewhere in the app.
    image = image.convert("RGB")

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Single-image batch: take the first (only) row of logits.
    logits = model(**inputs)['logits'].cpu().detach().numpy()[0]

    # Numerically stable softmax over the class axis.
    exp_shifted = np.exp(logits - logits.max())
    probs = exp_shifted / exp_shifted.sum()

    best = int(np.argmax(probs))
    print(labels[best], logits[best])

    return {name: float(p) for name, p in zip(labels, probs)}
 
# UI metadata shown by the Gradio interface.
title = "Geometric Shape Classifier"
description = "A geometric shape detector."
# Example images bundled with the app (paths relative to the working dir).
examples = ['example/1_None.jpg','example/2_Circle.jpg','example/3_Triangle.jpg','example/4_Square.jpg','example/5_Pentagone.jpg','example/6_Hexagone.jpg']
# NOTE(review): gr.inputs / gr.outputs, `interpretation` and `enable_queue`
# are legacy Gradio (<4.0) APIs — confirm against the pinned gradio version.
interpretation = 'default'
enable_queue = True

# Only start the server when run as a script, so importing this module
# (e.g. for testing) does not launch the app.
if __name__ == "__main__":
    gr.Interface(
        fn=predict,
        inputs=gr.inputs.Image(shape=(512, 512)),
        outputs=gr.outputs.Label(),
        title=title,
        description=description,
        examples=examples,
        interpretation=interpretation,
        enable_queue=enable_queue,
    ).launch()