"""Gradio demo for Pl@ntBERT: classify vegetation plots and find a missing species."""
from functools import lru_cache
from urllib.parse import quote

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline

# Shown when no picture can be found on floraveg.eu, or when the habitat
# prediction is below the requested confidence.
NOT_FOUND_IMAGE_URL = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"

FLORAVEG_HABITAT_URL = "https://floraveg.eu/habitat/overview/{}"
FLORAVEG_TAXON_URL = "https://floraveg.eu/taxon/overview/{}"
HABITAT_THUMB_PREFIX = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
TAXON_IMAGE_PREFIX = "https://files.ibot.cas.cz/cevs/images/taxa/large/"

# Seconds to wait for floraveg.eu before falling back to the placeholder image;
# without a timeout a stalled server would block the request handler forever.
REQUEST_TIMEOUT = 10


@lru_cache(maxsize=None)
def _load_pipeline(pipeline_task, model_name):
    """Build a Hugging Face pipeline once per (task, model) and reuse it."""
    return pipeline(pipeline_task, model=model_name)


def return_model(task):
    """Return the pipeline for *task*.

    ``'classification'`` selects the habitat text classifier; any other value
    selects the fill-mask model. Pipelines are cached, so model weights are
    loaded once per process instead of on every button click.
    """
    if task == 'classification':
        return _load_pipeline("text-classification", "CesarLeblanc/test_model")
    return _load_pipeline("fill-mask", "CesarLeblanc/fill_mask_model")


@lru_cache(maxsize=1)
def return_dataset():
    """Return the classification dataset, loaded on first use and then cached.

    In this file it is only read (its ``label`` feature names map model label
    ids to habitat names), so sharing one instance across requests is safe.
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")


def _is_confident(habitat_score, confidence):
    """Return True when *habitat_score* (a 0-1 probability) reaches *confidence* (a percentage).

    Shared by the text and image outputs so both always agree on the threshold.
    """
    return habitat_score * 100 >= confidence


def return_text(habitat_label, habitat_score, confidence):
    """Describe the predicted habitat, or explain that the confidence threshold was not met."""
    if _is_confident(habitat_score, confidence):
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."


def _scrape_image_url(page_url, src_prefix):
    """Return the ``src`` of the first ``<img>`` on *page_url* whose ``src`` starts with *src_prefix*.

    Falls back to NOT_FOUND_IMAGE_URL when the request fails or times out,
    the server answers with a non-200 status, or no matching image exists.
    """
    try:
        response = requests.get(page_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        return NOT_FOUND_IMAGE_URL
    if response.status_code != 200:
        return NOT_FOUND_IMAGE_URL
    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(src_prefix))
    return img_tag['src'] if img_tag else NOT_FOUND_IMAGE_URL


def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return a ``gr.Image`` with the floraveg.eu thumbnail of *habitat_label*.

    Below the confidence threshold the placeholder is returned directly,
    without querying floraveg.eu.
    """
    if not _is_confident(habitat_score, confidence):
        return gr.Image(value=NOT_FOUND_IMAGE_URL)
    # Quote the label so characters like '/', '?' or '#' cannot alter the URL path.
    page_url = FLORAVEG_HABITAT_URL.format(quote(habitat_label, safe=""))
    return gr.Image(value=_scrape_image_url(page_url, HABITAT_THUMB_PREFIX))


def return_species_image(species):
    """Return a ``gr.Image`` with the floraveg.eu picture of *species*.

    The first character is capitalised to match floraveg's taxon URLs
    (e.g. ``abies alba`` -> ``Abies alba``). An empty name yields the placeholder.
    """
    if not species:
        return gr.Image(value=NOT_FOUND_IMAGE_URL)
    species = species[0].capitalize() + species[1:]
    # Model tokens may contain '#' (sub-word pieces) or other URL-significant
    # characters; quoting keeps them inside the path component.
    page_url = FLORAVEG_TAXON_URL.format(quote(species, safe=""))
    return gr.Image(value=_scrape_image_url(page_url, TAXON_IMAGE_PREFIX))


def classification(text, typology, confidence):
    """Classify a comma-separated species list into a habitat.

    Returns a ``(message, gr.Image)`` pair for the two output components.
    *typology* is currently unused: EUNIS is the only option offered in the UI.
    """
    model = return_model("classification")
    label_names = return_dataset()['train'].features['label'].names
    prediction = model(text)[0]
    # The model's raw label is parsed as "<prefix>_<id>"; the id indexes the
    # dataset's label names (assumes labels like "LABEL_3" — confirm if the
    # model config ever gains real id2label names).
    habitat_label = label_names[int(prediction['label'].split('_')[1])]
    habitat_score = prediction['score']
    formatted_output = return_text(habitat_label, habitat_score, confidence)
    image_output = return_habitat_image(habitat_label, habitat_score, confidence)
    return formatted_output, image_output


def masking(text):
    """Predict the species missing from a comma-separated species list.

    Two mask tokens are appended (one per word of a binomial name). Returns a
    ``(message, gr.Image)`` pair for the two output components.
    """
    model = return_model("masking")
    masked_text = text + ', [MASK] [MASK]'
    # With two masks the pipeline yields one candidate list per mask; top_k=1
    # keeps only the best token for each.
    predictions = model(masked_text, top_k=1)
    new_species = ' '.join(candidates[0]['token_str'] for candidates in predictions)
    text = f"The last species from this vegetation plot is probably {new_species}."
    image = return_species_image(new_species)
    return text, image


with gr.Blocks() as demo:
    gr.Markdown("""# Pl@ntBERT""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""Classify vegetation plots!""")
        with gr.Row():
            with gr.Column():
                classify_species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                classify_typology = gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!")
                classify_confidence = gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction.")
            with gr.Column():
                classify_text = gr.Textbox()
                classify_image = gr.Image()
        classify_button = gr.Button("Classify")

    with gr.Tab("Missing species finding"):
        gr.Markdown("""Find the missing species!""")
        with gr.Row():
            find_species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                find_text = gr.Textbox()
                find_image = gr.Image()
        find_button = gr.Button("Find")

    # Event listeners are registered inside the Blocks context, as Gradio requires.
    classify_button.click(classification, inputs=[classify_species, classify_typology, classify_confidence], outputs=[classify_text, classify_image])
    find_button.click(masking, inputs=[find_species], outputs=[find_text, find_image])

demo.launch()