import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import pandas as pd

# Initialize models
# The classifier yields its 5 best habitats; classification() slices them down
# to the user-selected top-k. top_k=100 on the fill-mask model gives masking()
# a deep candidate list to skip over species already present in the plot.
classification_model = pipeline(
    "text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
mask_model = pipeline(
    "fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)

# Load data
# Lookup table used to turn a predicted 'EUNIS 2020 code' into its
# 'EUNIS-2021 habitat name'.
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')

# Placeholder displayed whenever no image can be retrieved from FloraVeg.EU.
IMAGE_NOT_FOUND_URL = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
# GBIF species name-matching endpoint.
GBIF_MATCH_URL = "https://api.gbif.org/v1/species/match"
# Seconds to wait for any outbound HTTP request before giving up.
REQUEST_TIMEOUT = 10


def _find_floraveg_image(page_url, src_prefix):
    """Return the src of the first <img> on page_url whose src starts with src_prefix.

    Falls back to IMAGE_NOT_FOUND_URL when the request fails, the server does not
    answer with HTTP 200, or no matching <img> tag is present.
    """
    try:
        response = requests.get(page_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        return IMAGE_NOT_FOUND_URL
    if response.status_code != 200:
        return IMAGE_NOT_FOUND_URL
    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(src_prefix))
    return img_tag['src'] if img_tag else IMAGE_NOT_FOUND_URL


def return_habitat_image(habitat_label):
    """Return a gr.Image with the FloraVeg.EU thumbnail for a habitat code.

    Args:
        habitat_label: Habitat code as predicted by the classifier.

    Returns:
        gr.Image showing the scraped thumbnail, or a placeholder image.
    """
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    image_url = _find_floraveg_image(
        floraveg_url, "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    )
    return gr.Image(value=image_url)


def return_species_image(species):
    """Return a gr.Image with the FloraVeg.EU picture of a species.

    Args:
        species: Species name in any case; it is capitalized before the lookup
            (e.g. "sparganium erectum" -> "Sparganium erectum").

    Returns:
        gr.Image showing the scraped picture, or a placeholder image.
    """
    species = species.capitalize()
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    image_url = _find_floraveg_image(
        floraveg_url, "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    )
    return gr.Image(value=image_url)


def gbif_normalization(text: str) -> str:
    """Normalize a comma-separated species list against the GBIF name matcher.

    Each stripped name is sent to the GBIF species-match API. When the JSON
    response carries a "species" field, that name replaces the input; otherwise
    (no match, HTTP/network failure, non-JSON body) the input name is kept.

    Args:
        text: Comma-separated species names.

    Returns:
        The names re-joined with ", " and lowercased.
    """
    species_gbif = []
    for species in (name.strip() for name in text.split(',')):
        try:
            # `params` makes requests URL-encode the name (spaces, '&', '#', accents).
            response = requests.get(
                GBIF_MATCH_URL, params={"name": species}, timeout=REQUEST_TIMEOUT
            )
            match = response.json()
        except (requests.RequestException, ValueError):
            match = {}
        matched = (match.get("species") or species) if isinstance(match, dict) else species
        species_gbif.append(matched)
    return ", ".join(species_gbif).lower()


def classification(text, k):
    """Predict the most likely habitats of a vegetation plot.

    Args:
        text: Comma-separated species names; normalized via gbif_normalization().
        k: Number of top habitats to report (the classifier yields at most 5).

    Returns:
        A (message, gr.Image) pair: a sentence listing the top-k habitat codes
        and naming the best one, plus an image of the best habitat.
    """
    # gr.Slider may deliver a float, which is not a valid slice bound.
    k = max(1, int(k))
    text = gbif_normalization(text)
    result = classification_model(text)
    # result[0] holds the ranked {'label', 'score'} dicts for our single input.
    habitat_labels = [res['label'] for res in result[0][:k]]
    best_label = habitat_labels[0]
    if len(habitat_labels) == 1:
        text = f"This vegetation plot belongs to the habitat {best_label}."
    else:
        text = f"This vegetation plot belongs to the habitats {', '.join(habitat_labels[:-1])} and {habitat_labels[-1]}."
    names = eunis_habitats.loc[
        eunis_habitats['EUNIS 2020 code'] == best_label, 'EUNIS-2021 habitat name'
    ]
    # Fall back to the bare code instead of crashing if the spreadsheet lacks it.
    habitat_name = names.iloc[0] if not names.empty else best_label
    text += f"\nThe most likely habitat is '{habitat_name}'."
    text += f"\nSee an image of this habitat (i.e., {best_label}) below."
    image_output = return_habitat_image(best_label)
    return text, image_output


def masking(text, k):
    """Greedily find the k species most likely missing from a vegetation plot.

    Each round inserts a [MASK] token at every position of the current species
    list, takes the highest-scoring prediction that is not already in the plot,
    and inserts it at its position before starting the next round.

    Args:
        text: Comma-separated species names; normalized via gbif_normalization().
        k: Number of missing species to find.

    Returns:
        A (message, gr.Image) pair: a sentence listing the predicted species with
        their 0-based insertion positions (relative to the list as it stood in
        that round) and the completed plot, plus an image of the first predicted
        species.
    """
    # gr.Slider may deliver a float; range() needs an int.
    k = max(1, int(k))
    text = gbif_normalization(text)
    text_split = text.split(', ')
    best_predictions = []
    best_positions = []
    best_sentence = None
    for _ in range(k):
        present = set(text_split)  # species the model must not propose again
        best = None  # (score, species, position, sentence) for this round
        for i in range(len(text_split) + 1):
            # '[MASK]' is hard-coded; assumes a BERT-style tokenizer.
            masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
            # Run the model once per position and keep its first returned
            # candidate that is not already part of the plot.
            candidate = next(
                (p for p in mask_model(masked_text) if p['token_str'] not in present),
                None,
            )
            if candidate is not None and (best is None or candidate['score'] > best[0]):
                best = (candidate['score'], candidate['token_str'], i, candidate['sequence'])
        if best is None:
            # Every candidate at every position is already in the plot.
            break
        _, species, position, best_sentence = best
        best_predictions.append(species)
        best_positions.append(position)
        text_split.insert(position, species)

    if not best_predictions:
        return (
            "No missing species could be found for this vegetation plot.",
            gr.Image(value=IMAGE_NOT_FOUND_URL),
        )
    if len(best_predictions) == 1:
        text = f"The most likely missing species is {best_predictions[0]} (position {best_positions[0]})."
    else:
        text = f"The most likely missing species are {', '.join(best_predictions[:-1])} and {best_predictions[-1]} (positions {', '.join(map(str, best_positions[:-1]))} and {best_positions[-1]})."
    text += f"\nThe new vegetation plot is {best_sentence}."
    text += f"\nSee an image of the most likely species (i.e., {best_predictions[0]}) below."
    # Show the first (most likely) predicted species as a whole name.
    image = return_species_image(best_predictions[0])
    return text, image


with gr.Blocks() as demo:
    gr.Markdown("\n\nPl@ntBERT\n\n")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("\n\nClassification of vegetation plots!\n\n")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top habitats to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("\nAn example of input\n")
        gr.Examples(
            [["sparganium erectum, calystegia sepium, persicaria amphibia", 1]],
            [species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("\n\nFinding the missing species!\n\n")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("\nAn example of input\n")
        gr.Examples(
            [["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", 1]],
            [species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

demo.launch()