# plantbert_space / app.py
# Commit 736419a (CesarLeblanc), 7.51 kB
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import pandas as pd
# Initialize models: each pipeline loads both its weights and its tokenizer
# from the same local directory.
_TEXT_CLASSIFICATION_DIR = "models/text_classification_model"
_FILL_MASK_DIR = "models/fill_mask_model"
classification_model = pipeline(
    "text-classification",
    model=_TEXT_CLASSIFICATION_DIR,
    tokenizer=_TEXT_CLASSIFICATION_DIR,
    top_k=5,  # matches the maximum of the "Top-k" slider in the classification tab
)
mask_model = pipeline(
    "fill-mask",
    model=_FILL_MASK_DIR,
    tokenizer=_FILL_MASK_DIR,
    top_k=100,  # candidate pool that masking() scans for species not already in the plot
)
# Load data: EUNIS habitat table, used by classification() to map a habitat code to its name.
eunis_habitats = pd.read_excel("data/eunis_habitats.xlsx")
def return_habitat_image(habitat_label, use_floraveg=False):
    """Return a ``gr.Image`` illustrating the habitat ``habitat_label``.

    Args:
        habitat_label: Habitat code, used as the last path segment of the
            FloraVeg.EU habitat overview URL.
        use_floraveg: When True, scrape the habitat thumbnail from the
            FloraVeg.EU overview page, falling back to a "not found" image if
            the request fails, the page is not HTTP 200, or no thumbnail is
            found. Defaults to False: while the app does not have the rights to
            display FloraVeg.EU pictures it shows a fixed placeholder, and no
            HTTP request is made.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    if not use_floraveg:
        # While we don't have the rights, skip the network round-trip entirely.
        return gr.Image(value=placeholder_url)

    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # Bounded wait so a slow or unreachable server cannot hang the request handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        return gr.Image(value=not_found_url)
    if response.status_code != 200:
        return gr.Image(value=not_found_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find(
        'img',
        src=lambda src: src and src.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"),
    )
    return gr.Image(value=img_tag['src'] if img_tag else not_found_url)
def return_species_image(species, use_floraveg=False):
    """Return a ``gr.Image`` illustrating ``species``.

    Args:
        species: Species name. It is passed through ``str.capitalize`` (first
            letter upper-case, the rest lower-case) to build the FloraVeg.EU
            taxon overview URL.
        use_floraveg: When True, scrape the species picture from the
            FloraVeg.EU overview page, falling back to a "not found" image if
            the request fails, the page is not HTTP 200, or no picture is
            found. Defaults to False: while the app does not have the rights to
            display FloraVeg.EU pictures it shows a fixed placeholder, and no
            HTTP request is made.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    if not use_floraveg:
        # While we don't have the rights, skip the network round-trip entirely.
        return gr.Image(value=placeholder_url)

    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"
    try:
        # Bounded wait so a slow or unreachable server cannot hang the request handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        return gr.Image(value=not_found_url)
    if response.status_code != 200:
        return gr.Image(value=not_found_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find(
        'img',
        src=lambda src: src and src.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"),
    )
    return gr.Image(value=img_tag['src'] if img_tag else not_found_url)
_GBIF_SPECIES_MATCH_URL = "https://api.gbif.org/v1/species/match"


def _gbif_species_name(name):
    """Return the ``species`` field of GBIF's match for ``name``, or ``name`` itself.

    Network errors, non-2xx responses and non-JSON bodies are treated as "no
    match", so a GBIF outage degrades to keeping the user's spelling instead
    of crashing the request.
    """
    try:
        # ``params=`` lets requests URL-encode the name (spaces, '&', '#', ...).
        response = requests.get(_GBIF_SPECIES_MATCH_URL, params={"name": name}, timeout=10)
        response.raise_for_status()
        match = response.json()
    except (requests.RequestException, ValueError):  # ValueError covers JSON decode errors
        return name
    # NOTE(review): 'species' is presumably absent when GBIF finds no
    # species-level match; in that case the input name is kept.
    if isinstance(match, dict) and 'species' in match:
        return match['species']
    return name


def gbif_normalization(text):
    """Normalize a comma-separated list of species names with the GBIF species-match API.

    Args:
        text: Species names separated by commas; surrounding whitespace around
            each name is ignored.

    Returns:
        The names, each replaced by GBIF's matched species name when one is
        returned (otherwise kept as given), joined with ", " and lowercased.
        Empty entries (e.g. from a trailing or doubled comma) are dropped
        without querying GBIF.
    """
    names = [name.strip() for name in text.split(',')]
    normalized = [_gbif_species_name(name) for name in names if name]
    return ", ".join(normalized).lower()
def classification(text, k):
    """Classify a vegetation plot into its most likely habitats.

    Args:
        text: Comma-separated species names; normalized via
            ``gbif_normalization`` before classification.
        k: Number of top habitats to report. Coerced to ``int`` (in case the
            slider delivers a float) and clamped to at least 1; the classifier
            pipeline itself returns at most ``top_k`` labels.

    Returns:
        A ``(message, image)`` tuple: a sentence listing the predicted habitat
        codes, the name of the top habitat when it is present in
        ``eunis_habitats``, and the image component returned by
        ``return_habitat_image`` for the top habitat.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    result = classification_model(text)
    # With top_k set on the pipeline, a single string gives [[{label, score}, ...]];
    # NOTE(review): some transformers versions return the flat inner list instead, so accept both.
    predictions = result[0] if result and isinstance(result[0], list) else result
    habitat_labels = [prediction['label'] for prediction in predictions[:k]]

    if len(habitat_labels) == 1:
        text = f"This vegetation plot belongs to the habitat {habitat_labels[0]}."
    else:
        text = f"This vegetation plot belongs to the habitats {', '.join(habitat_labels[:-1])} and {habitat_labels[-1]}."

    top_habitat = habitat_labels[0]
    habitat_names = eunis_habitats.loc[
        eunis_habitats['EUNIS 2020 code'] == top_habitat, 'EUNIS-2021 habitat name'
    ]
    # A code missing from the table must not crash the request; just omit its name.
    if not habitat_names.empty:
        text += f"\nThe most likely habitat is '{habitat_names.iloc[0]}'."
    text += f"\nSee image of the most likely habitat, {top_habitat}, below."
    image_output = return_habitat_image(top_habitat)
    return text, image_output
def masking(text, k):
    """Greedily predict the species most likely missing from a vegetation plot.

    In each of ``k`` rounds, every insertion position in the current species
    list is masked in turn. For each position, the fill-mask model's
    highest-ranked suggestion that is not already in the plot is taken. The
    best-scoring (position, species) pair is then inserted before the next
    round.

    Args:
        text: Comma-separated species names; normalized via
            ``gbif_normalization`` first.
        k: Number of species to add. Coerced to ``int`` and clamped to at
            least 1. Fewer may be added if the model runs out of new candidates.

    Returns:
        A ``(message, image)`` tuple. The message describes the added species,
        their insertion positions and the completed plot. The image is the
        component returned by ``return_species_image`` for the first species
        added, or ``None`` when no species could be predicted.
    """
    k = max(1, int(k))
    text = gbif_normalization(text)
    text_split = text.split(', ')
    mask_token = mask_model.tokenizer.mask_token
    best_predictions = []
    best_positions = []
    best_sentence = None
    for _ in range(k):
        # Species already in the plot (including ones added in earlier rounds) are not valid suggestions.
        excluded = set(text_split)
        max_score = 0
        best_prediction = None
        best_position = None
        for i in range(len(text_split) + 1):
            masked_text = ', '.join(text_split[:i] + [mask_token] + text_split[i:])
            # Run the model once per position and take its best-ranked new species
            # (None if every returned candidate is already in the plot).
            prediction = next(
                (p for p in mask_model(masked_text) if p['token_str'] not in excluded),
                None,
            )
            if prediction is not None and prediction['score'] > max_score:
                max_score = prediction['score']
                best_prediction = prediction['token_str']
                best_position = i
                best_sentence = prediction['sequence']
        if best_prediction is None:
            # No position yielded a new species: stop instead of inserting None.
            break
        best_predictions.append(best_prediction)
        best_positions.append(best_position)
        text_split.insert(best_position, best_prediction)

    if not best_predictions:
        return "No missing species could be predicted for this vegetation plot.", None

    if len(best_predictions) == 1:
        text = f"The most likely missing species is {best_predictions[0]} (position {best_positions[0]})."
    else:
        text = (
            f"The most likely missing species are {', '.join(best_predictions[:-1])} and {best_predictions[-1]} "
            f"(positions {', '.join(map(str, best_positions[:-1]))} and {best_positions[-1]})."
        )
    text += f"\nThe new vegetation plot is {best_sentence}."
    text += f"\nSee image of the most likely species, {best_predictions[0]}, below."
    # Pass the full species name announced in the message above.
    image = return_species_image(best_predictions[0])
    return text, image
# Two-tab Gradio UI: habitat classification and missing-species prediction.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top habitats to display.")
            with gr.Column():
                text_classification = gr.Textbox()
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["sparganium erectum, calystegia sepium, persicaria amphibia", 1]],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,  # runs classification() on the example at startup
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            # Inputs stacked in their own column, as in the classification tab.
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox()
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", 1]],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,  # runs masking() on the example at startup
        )

    # Event listeners are registered inside the Blocks context.
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

demo.launch()