# plantbert_space / app.py
# CesarLeblanc's picture
# Update app.py
# 3c63477 verified
# raw
# history blame
# 7.23 kB
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import random
# Text-classification pipeline; its top prediction supplies the habitat label
# and score used by classification().
classification_model = pipeline(
    "text-classification",
    model="plantbert_text_classification_model",
    tokenizer="plantbert_text_classification_model",
)
# Fill-mask pipeline; masking() queries it with a "[MASK]" placeholder inserted
# into the species list.
mask_model = pipeline(
    "fill-mask",
    model="plantbert_fill_mask_model",
    tokenizer="plantbert_fill_mask_model",
)
def return_text(habitat_label, habitat_score, confidence):
    """Build the sentence that reports the habitat prediction.

    Args:
        habitat_label: Label of the predicted habitat.
        habitat_score: Score of the prediction as a fraction; it is multiplied
            by 100 before being compared and displayed as a percentage.
        confidence: Minimum confidence, in percent, required to report the
            habitat.

    Returns:
        A sentence naming the habitat and its probability when the score
        reaches ``confidence``; otherwise a sentence stating that no habitat
        can be assigned at that confidence level.
    """
    # ">=" rather than ">": the fallback message promises "at least
    # {confidence}%", so a score exactly on the threshold must be accepted.
    # It also matches return_habitat_image, which only rejects scores that are
    # strictly below the threshold.
    if habitat_score * 100 >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return the image component to display for a predicted habitat.

    The FloraVeg.EU overview page of the habitat is requested and searched for
    an ``<img>`` whose ``src`` starts with the syntaxa thumbnail prefix. A
    "not found" picture is selected when the page cannot be fetched, when it
    has no such image, or when the score is below ``confidence``.

    Args:
        habitat_label: Label of the predicted habitat, used in the page URL.
        habitat_score: Score of the prediction as a fraction; it is multiplied
            by 100 before being compared with ``confidence``.
        confidence: Minimum confidence, in percent.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    image_url = not_found_url
    try:
        # Without a timeout a stalled server would block the UI callback forever.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Best effort: an unreachable site is treated like a missing picture.
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"))
        if img_tag:
            image_url = img_tag['src']
    if habitat_score * 100 < confidence:
        image_url = not_found_url
    # NOTE(review): this unconditional assignment replaces whatever URL was
    # selected above, so the same placeholder picture is always displayed.
    # Kept as in the original -- presumably a deliberate temporary override;
    # confirm before removing it.
    image_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    return gr.Image(value=image_url)
def return_species_image(species):
    """Return the image component to display for a species.

    The FloraVeg.EU overview page of the species is requested and searched for
    an ``<img>`` whose ``src`` starts with the large taxa picture prefix. A
    "not found" picture is selected when the page cannot be fetched or has no
    such image.

    Args:
        species: Species name; it is capitalized before being put in the URL.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    species = species.capitalize()
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    image_url = not_found_url
    try:
        # Without a timeout a stalled server would block the UI callback forever.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Best effort: an unreachable site is treated like a missing picture.
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"))
        if img_tag:
            image_url = img_tag['src']
    # NOTE(review): this unconditional assignment replaces whatever URL was
    # selected above, so the same placeholder picture is always displayed.
    # Kept as in the original -- presumably a deliberate temporary override;
    # confirm before removing it.
    image_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    return gr.Image(value=image_url)
def gbif_normalization(text):
    """Normalize a comma-separated list of species names with the GBIF API.

    Each non-blank name is sent to the GBIF ``species/match`` endpoint. When
    the JSON answer contains a ``species`` entry, that value replaces the
    name; otherwise (including when the request or the JSON decoding fails)
    the name is kept as typed.

    Args:
        text: Comma-separated species names.

    Returns:
        The resulting names joined with ", " and lower-cased. Blank entries
        (empty input, doubled or trailing commas) are dropped.
    """
    url = "https://api.gbif.org/v1/species/match"
    species_gbif = []
    for species in text.split(','):
        species = species.strip()
        if not species:
            # Nothing to look up; skipping avoids a useless API call and an
            # empty element in the output.
            continue
        try:
            # Passing the name through ``params`` lets requests URL-encode it,
            # instead of splicing raw user text into the query string.
            response = requests.get(url, params={"name": species}, timeout=10)
            match = response.json()
        except (requests.RequestException, ValueError):
            # Best effort: network or decoding problems keep the typed name.
            match = {}
        if isinstance(match, dict) and 'species' in match:
            species_gbif.append(match["species"])
        else:
            species_gbif.append(species)
    return ", ".join(species_gbif).lower()
def classification(text, typology, confidence):
    """Classify a vegetation plot and build the outputs shown in the UI.

    The species list is normalized with ``gbif_normalization`` and passed to
    ``classification_model``; only the first prediction is used.

    Args:
        text: Comma-separated species names.
        typology: Selected typology; accepted but not used by this function.
        confidence: Minimum confidence, in percent, forwarded to the helpers
            that build the outputs.

    Returns:
        A tuple ``(message, image)`` produced by ``return_text`` and
        ``return_habitat_image`` respectively.
    """
    normalized = gbif_normalization(text)
    top_prediction = classification_model(normalized)[0]
    label = top_prediction['label']
    score = top_prediction['score']
    return (
        return_text(label, score, confidence),
        return_habitat_image(label, score, confidence),
    )
def masking(text):
    """Find the most likely missing species and where it belongs in the list.

    The species list is normalized with ``gbif_normalization``; a "[MASK]"
    placeholder is then inserted at every possible slot (before the first
    species, between each pair, after the last one). For each slot the first
    prediction of ``mask_model`` is read, and the slot whose prediction has
    the highest score wins; on ties the earliest slot is kept.

    Args:
        text: Comma-separated species names.

    Returns:
        A tuple ``(message, image)``: the sentence naming the predicted
        species and its 0-based slot, and the result of
        ``return_species_image`` for that species.
    """
    species_list = gbif_normalization(text).split(', ')
    top_score = 0
    best_prediction = None
    best_position = None
    # Slot ``position`` puts the mask in front of the species currently at
    # that index; the last slot (== len(species_list)) appends it at the end.
    for position in range(len(species_list) + 1):
        with_mask = species_list[:position] + ['[MASK]'] + species_list[position:]
        candidate = mask_model(', '.join(with_mask))[0]
        # Strict comparison: an equal score never displaces an earlier slot.
        if candidate['score'] > top_score:
            top_score = candidate['score']
            best_prediction = candidate['token_str']
            best_position = position
    message = f"The most likely missing species in position {best_position} is: {best_prediction}."
    return message, return_species_image(best_prediction)
# Two-tab Gradio interface: one tab per model.
# NOTE(review): the nesting of rows/columns below is the conventional reading
# of the ``with`` structure -- confirm against the rendered layout.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    # Tab 1: habitat classification of a vegetation plot.
    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(
                    lines=2,
                    label="Species",
                    placeholder="Enter a list of comma-separated binomial names here.",
                )
                typology = gr.Dropdown(
                    ["EUNIS"],
                    value="EUNIS",
                    label="Typology",
                    info="Will add more typologies later!",
                )
                confidence = gr.Slider(
                    0,
                    100,
                    value=90,
                    label="Confidence",
                    info="Choose the level of confidence for the prediction.",
                )
            with gr.Column():
                text_output_1 = gr.Textbox()
                text_output_2 = gr.Image()
        text_button = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90]],
            inputs=[species, typology, confidence],
            outputs=[text_output_1, text_output_2],
            fn=classification,
            cache_examples=True,
        )

    # Tab 2: prediction of a missing species.
    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            species_2 = gr.Textbox(
                lines=2,
                label="Species",
                placeholder="Enter a list of comma-separated binomial names here.",
            )
            with gr.Column():
                image_output_1 = gr.Textbox()
                image_output_2 = gr.Image()
        image_button = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            inputs=[species_2],
            outputs=[image_output_1, image_output_2],
            fn=masking,
            cache_examples=True,
        )

    # Wire each button to its handler.
    text_button.click(
        classification,
        inputs=[species, typology, confidence],
        outputs=[text_output_1, text_output_2],
    )
    image_button.click(
        masking,
        inputs=[species_2],
        outputs=[image_output_1, image_output_2],
    )

demo.launch()