import functools

import gradio as gr
from transformers import pipeline
# Hugging Face Hub id of the fine-tuned ViT houseplant classifier.
MODEL_ID = "bhargob11/vit-base-patch16-224-in21k-finetuned-housplants"


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Build the image-classification pipeline once and reuse it for every request."""
    return pipeline("image-classification", MODEL_ID)


def image_classifier(image):
    """Return the label of the highest-scoring prediction for *image*.

    Args:
        image: The image to classify (a PIL image when wired to a
            ``gr.Image(type="pil")`` input), or ``None`` if nothing was uploaded.

    Returns:
        The label string of the prediction with the largest ``score``.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Gradio hands the function None when the user submits without an
        # upload; show a readable message instead of a pipeline traceback.
        raise gr.Error("Please upload an image.")
    # The pipeline is cached, so the model is loaded on the first request only
    # rather than on every call.
    predictions = _get_classifier()(image)
    # Pick by score explicitly instead of relying on the pipeline's ordering.
    best = max(predictions, key=lambda prediction: prediction["score"])
    return best["label"]
# Close any Gradio apps left running from a previous execution of this script
# (e.g. a re-run notebook cell) before building and launching a new one.
gr.close_all()

# Single-image-in, single-label-out UI around image_classifier.
demo = gr.Interface(
    fn=image_classifier,
    inputs=[gr.Image(label="Upload image", type="pil")],
    outputs=[gr.Textbox(label="Category")],
    title="Image Classification with Fine-Tuned ViT Model",
    description="Classify any houseplant images",
    allow_flagging="never",
)

# share=True asks Gradio to also create a public share link for the app.
demo.launch(share=True)