"""Pl@ntBERT Gradio demo: classify a vegetation plot into a habitat.

The user enters a comma-separated list of species, a text-classification
pipeline predicts the habitat label, and a thumbnail for that habitat is
scraped from floraveg.eu (with a placeholder image as fallback).
"""

from urllib.parse import quote

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline

classifier = pipeline("text-classification", model="CesarLeblanc/test_model")
dataset = load_dataset("CesarLeblanc/text_classification_dataset")

# Class names of the training split; a pipeline label "LABEL_<i>" is mapped
# to label_names[i]. Resolved once here instead of on every request.
label_names = dataset['train'].features['label'].names

FLORAVEG_URL = "https://floraveg.eu/habitat/overview/{}"
THUMBNAIL_PREFIX = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
IMAGE_NOT_FOUND_URL = (
    "https://www.salonlfc.com/wp-content/uploads/2018/01/"
    "image-not-found-scaled-1150x647.png"
)
# Seconds to wait for floraveg.eu before falling back to the placeholder.
# Without a timeout, requests.get can block the handler indefinitely.
REQUEST_TIMEOUT = 10


def _habitat_name(raw_label):
    """Map a generic pipeline label such as ``"LABEL_3"`` to its class name.

    A label that is not of the ``LABEL_<index>`` form is returned unchanged
    instead of raising.
    """
    prefix = "LABEL_"
    index = raw_label[len(prefix):]
    if raw_label.startswith(prefix) and index.isdigit():
        return label_names[int(index)]
    return raw_label


def _fetch_habitat_image(habitat_label, timeout=REQUEST_TIMEOUT):
    """Return the URL of the floraveg.eu thumbnail for ``habitat_label``.

    Falls back to ``IMAGE_NOT_FOUND_URL`` when the overview page cannot be
    fetched (network error, timeout, non-200 status) or contains no <img>
    whose ``src`` starts with ``THUMBNAIL_PREFIX``.
    """
    # Percent-encode the label so it is always a single, valid path segment.
    page_url = FLORAVEG_URL.format(quote(habitat_label, safe=""))
    try:
        response = requests.get(page_url, timeout=timeout)
    except requests.RequestException:
        # Best effort: a missing picture must not break the classification.
        return IMAGE_NOT_FOUND_URL
    if response.status_code != 200:
        return IMAGE_NOT_FOUND_URL
    soup = BeautifulSoup(response.text, 'html.parser')
    # bs4 calls the filter with None for <img> tags that have no src.
    img_tag = soup.find(
        'img', src=lambda src: src and src.startswith(THUMBNAIL_PREFIX)
    )
    if img_tag is None:
        return IMAGE_NOT_FOUND_URL
    return img_tag['src']


def text_classification(text, typology, confidence):
    """Classify a vegetation plot and return a description and an image.

    Args:
        text: Comma-separated binomial species names typed by the user.
        typology: Value of the "Typology" dropdown (only "EUNIS" offered).
        confidence: Value of the "Confidence" slider (0-100).

    Returns:
        A ``(message, image_url)`` tuple feeding the Textbox and image
        outputs of the interface.
    """
    # NOTE(review): ``typology`` and ``confidence`` come from the UI but are
    # not used yet, so the slider currently has no effect on the result.
    if not text or not text.strip():
        return "Please enter at least one species.", IMAGE_NOT_FOUND_URL
    result = classifier(text)[0]
    habitat_label = _habitat_name(result['label'])
    habitat_score = result['score']
    formatted_output = (
        f"This vegetation plot belongs to the habitat {habitat_label} "
        f"with the probability {habitat_score*100:.2f}%"
    )
    # Gradio image outputs accept a URL string directly; no need to build a
    # new gr.Image component on every call.
    return formatted_output, _fetch_habitat_image(habitat_label)


examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 50],
    ["thinopyrum junceum, cakile maritima", "EUNIS", 50],
]

io = gr.Interface(
    fn=text_classification,
    inputs=[
        gr.Textbox(
            lines=2,
            label="List of comma-separated binomial names of species (see examples)",
            placeholder="Enter species here...",
        ),
        gr.Dropdown(["EUNIS"], label="Typology", info="Will add more typologies later!"),
        gr.Slider(
            0,
            100,
            value=50,
            label="Confidence",
            info="Choose the level of confidence for the prediction",
        ),
    ],
    outputs=[gr.Textbox(lines=2, label="Vegetation Plot Classification Result"), "image"],
    title="Pl@ntBERT",
    description=(
        "Vegetation Plot Classification: enter the species found in a vegetation "
        "plot and see its EUNIS habitat!"
    ),
    examples=examples,
)

if __name__ == "__main__":
    io.launch()