0-ma's picture
Update app.py
479feb3 verified
raw
history blame
2.75 kB
import gradio as gr
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
import requests
# Class names for the classifier outputs. The order matters: predict() pairs
# the label at position i with the logit at position i.
labels = ["None", "Circle", "Triangle", "Square", "Pentagon", "Hexagon"]
# Models offered in the dropdown. Each Hub repository id is used both as the
# display name (key) and as the id passed to `from_pretrained` (value).
models = {
    repo_id: repo_id
    for repo_id in (
        "0-ma/swin-geometric-shapes-tiny",
        "0-ma/mobilenet-v2-geometric-shapes",
        "0-ma/focalnet-geometric-shapes-tiny",
        "0-ma/efficientnet-b2-geometric-shapes",
        "0-ma/beit-geometric-shapes-base",
        "0-ma/mit-b0-geometric-shapes",
        "0-ma/vit-geometric-shapes-base",
        "0-ma/resnet-geometric-shapes",
        "0-ma/vit-geometric-shapes-tiny",
    )
}
# Eagerly load the image processor and the classifier of every listed model at
# start-up, keyed by the dropdown name, so predict() only does a dict lookup.
feature_extractors = {
    name: AutoImageProcessor.from_pretrained(repo_id)
    for name, repo_id in models.items()
}
classification_models = {
    name: AutoModelForImageClassification.from_pretrained(repo_id)
    for name, repo_id in models.items()
}
def predict(image, selected_model):
    """Classify the geometric shape in ``image`` with the selected model.

    The raw logits are turned into scores by clamping negative values to zero
    and dividing by the sum of what remains (a linear rescaling, not a
    softmax). Only classes with a strictly positive logit are reported.

    Args:
        image: Image handed over by the Gradio image input (configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.
        selected_model: Key into ``feature_extractors`` and
            ``classification_models`` identifying the model to run.

    Returns:
        A dict mapping label -> normalised score (floats summing to 1) for
        every class whose logit is strictly positive. The dict is empty when
        no logit is positive.

    Raises:
        gr.Error: If no image was provided or ``selected_model`` is not one
            of the loaded models.
    """
    # Imported here so the block does not depend on a file-level import;
    # torch is already required by the ``return_tensors="pt"`` call below.
    import torch

    if image is None:
        raise gr.Error("Please provide an image to classify.")
    if selected_model not in classification_models:
        raise gr.Error("Please select a model from the dropdown.")

    feature_extractor = feature_extractors[selected_model]
    model = classification_models[selected_model]

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs)['logits'].detach().cpu().numpy()[0]

    # Clamp into a new array instead of mutating ``logits`` through an alias.
    positive = np.clip(logits, 0, None)
    total = positive.sum()
    # Nothing positive (or NaN logits): return no scores rather than 0/0.
    if not total > 0:
        return {}

    scores = positive / total
    # zip() also guards against a model emitting fewer logits than labels.
    return {
        label: float(score)
        for label, clamped, score in zip(labels, positive, scores)
        if clamped > 0
    }
title = "Geometric Shape Classifier"
description = "Select a model to classify geometric shapes."
# One sample image per class. The interface has two inputs (image, model), so
# every example row is a nested list with one value per input component; a
# flat list is only accepted by Gradio when there is a single input. Each
# image is paired with the first model listed in `models`.
examples = [
    [image_path, next(iter(models))]
    for image_path in (
        'example/1_None.jpg',
        'example/2_Circle.jpg',
        'example/3_Triangle.jpg',
        'example/4_Square.jpg',
        'example/5_Pentagone.jpg',
        'example/6_Hexagone.jpg',
    )
]
# Build the UI: an image upload plus a dropdown for model selection. The
# dropdown is given an explicit initial value so predict() never receives
# None for the model when the user submits without touching it.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(
            list(models),
            value=next(iter(models)),
            label="Select Model",
        ),
    ],
    outputs=gr.Label(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()