# plantbert_space / app.py
# Last commit 7e0319c ("Update app.py") by CesarLeblanc; file size 1.23 kB.
import gradio as gr
from transformers import pipeline
from datasets import load_dataset
# Hugging Face Hub repositories: the fine-tuned text classifier, and the
# dataset whose "label" feature holds the human-readable class names.
_MODEL_REPO = "CesarLeblanc/test_model"
_DATASET_REPO = "CesarLeblanc/text_classification_dataset"

classifier = pipeline(task="text-classification", model=_MODEL_REPO)
dataset = load_dataset(_DATASET_REPO)
def text_classification(text):
    """Classify a free-text list of plant species and describe the result.

    Args:
        text: Species list as typed into the Gradio textbox.

    Returns:
        A sentence naming the predicted class and the model's confidence as
        a percentage, or a prompt asking for input when *text* is empty or
        whitespace only.
    """
    # An empty box would still produce a (meaningless) prediction; ask for input instead.
    if text is None or not text.strip():
        return "Please enter at least one species."

    # For a single input string the pipeline returns a one-element list of
    # {'label': ..., 'score': ...} dicts.
    prediction = classifier(text)[0]
    raw_label = prediction['label']
    habitat_score = prediction['score']

    # The model reports labels of the form "<prefix>_<index>" (presumably the
    # transformers default "LABEL_<n>"); map the index back to the class name
    # stored in the dataset's label feature. rpartition takes the LAST
    # underscore, so prefixes containing underscores still parse. If the
    # suffix is not an integer, the label is assumed to already be a name.
    label_names = dataset['train'].features['label'].names
    _, _, suffix = raw_label.rpartition('_')
    try:
        habitat_label = label_names[int(suffix)]
    except ValueError:
        habitat_label = raw_label

    return f"This vegetation plot is {habitat_label} with the probability {habitat_score*100:.2f}%"
# Sample species lists offered as clickable examples under the input box.
examples = [
    "quercus robur, betula pendula, holcus lanatus, lonicera periclymenum, carex arenaria, poa trivialis",
    "thinopyrum junceum, cakile maritima",
]

# Input and output widgets: one multi-line box for the species, one for the verdict.
species_box = gr.Textbox(lines=2, label="Text", placeholder="Enter species here...")
result_box = gr.Textbox(lines=2, label="Text Classification Result")

io = gr.Interface(
    fn=text_classification,
    inputs=species_box,
    outputs=result_box,
    title="Vegetation Plot Classification",
    description="Enter the species and see the vegetation plot classification result!",
    examples=examples,
)

# Start the web server; this runs as soon as the module is executed.
io.launch()