Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import pipeline | |
import requests | |
from bs4 import BeautifulSoup | |
import random | |
# Load the two Pl@ntBERT pipelines once at import time.
# NOTE(review): the model names have no hub namespace, so they are presumably
# local directories shipped alongside this app -- confirm.

# Habitat classifier: comma-separated species list -> habitat label.
classification_model = pipeline(
    task="text-classification",
    model="plantbert_text_classification_model",
    tokenizer="plantbert_text_classification_model",
)

# Masked-species predictor: returns the 100 best candidates for a [MASK] slot.
mask_model = pipeline(
    task="fill-mask",
    model="plantbert_fill_mask_model",
    tokenizer="plantbert_fill_mask_model",
    top_k=100,
)
def return_habitat_image(habitat_label, *, scrape=False, timeout=10):
    """Return a ``gr.Image`` illustrating the habitat ``habitat_label``.

    By default a fixed Pl@ntNet placeholder picture is returned, because the
    rights to show FloraVeg pictures have not been obtained yet. In that case
    no network request is made at all.

    Args:
        habitat_label: Habitat label predicted by the classifier; it is
            interpolated into the FloraVeg habitat overview URL.
        scrape: When True, look up the habitat thumbnail on floraveg.eu
            instead of using the placeholder.
        timeout: Seconds to wait for floraveg.eu when ``scrape`` is True.

    Returns:
        A ``gr.Image`` whose value is an image URL. When scraping fails
        (network error, non-200 status, or no matching ``<img>``), an
        "image not found" picture is used.
    """
    if not scrape:
        # Placeholder while we don't have the rights to FloraVeg images.
        # Returning early avoids a wasted request whose result would be
        # discarded, and a crash when floraveg.eu is unreachable.
        return gr.Image(value="https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg")

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumb_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # Best effort: an unreachable site just means "image not found".
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, "html.parser")
        img_tag = soup.find("img", src=lambda src: src and src.startswith(thumb_prefix))
        if img_tag:
            image_url = img_tag["src"]
    return gr.Image(value=image_url)
def return_species_image(species, *, scrape=False, timeout=10):
    """Return a ``gr.Image`` illustrating ``species``.

    By default a fixed Pl@ntNet placeholder picture is returned and no
    network request is made. ``species`` is not even inspected in that case,
    so a ``None`` species no longer crashes.

    Args:
        species: Species name, e.g. "vaccinium myrtillus".
        scrape: When True, look up the species picture on floraveg.eu
            instead of using the placeholder.
        timeout: Seconds to wait for floraveg.eu when ``scrape`` is True.

    Returns:
        A ``gr.Image`` whose value is an image URL. When scraping fails
        (network error, non-200 status, or no matching ``<img>``), an
        "image not found" picture is used.
    """
    if not scrape:
        # Placeholder while we don't have the rights to FloraVeg images.
        # Skipping the request avoids discarded work and crashes on
        # network failure.
        return gr.Image(value="https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg")

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    photo_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    # The name is capitalised before building the taxon URL
    # (presumably FloraVeg expects "Genus species" -- confirm).
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, "html.parser")
        img_tag = soup.find("img", src=lambda src: src and src.startswith(photo_prefix))
        if img_tag:
            image_url = img_tag["src"]
    return gr.Image(value=image_url)
_GBIF_MATCH_URL = "https://api.gbif.org/v1/species/match"


def _gbif_species_match(name, timeout=10):
    """Return GBIF's species name for ``name``, or None if it cannot be obtained."""
    try:
        response = requests.get(_GBIF_MATCH_URL, params={"name": name}, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        # Network failure, HTTP error, or non-JSON body: let the caller
        # keep the user's spelling instead of crashing the app.
        return None
    if not isinstance(payload, dict):
        return None
    # No "species" key in the match result -> no species-level match.
    return payload.get("species")


def gbif_normalization(text, resolver=None):
    """Normalise a comma-separated species list against the GBIF backbone.

    Each name is stripped and replaced by the species name returned by the
    resolver. The original name is kept when the resolver returns nothing.
    Empty entries (e.g. from "a,,b" or a trailing comma) are dropped.

    Args:
        text: Comma-separated species names.
        resolver: Optional callable ``name -> str | None``. Defaults to a
            GBIF ``/v1/species/match`` lookup. Injecting one allows offline
            use and testing.

    Returns:
        The normalised names joined with ", " and lower-cased.
    """
    if resolver is None:
        resolver = _gbif_species_match
    names = [name.strip() for name in text.split(",")]
    normalized = [resolver(name) or name for name in names if name]
    return ", ".join(normalized).lower()
def classification(text):
    """Classify a vegetation plot into a habitat.

    Args:
        text: Comma-separated species names typed by the user.

    Returns:
        A ``(message, image)`` pair: a sentence naming the predicted habitat
        label, and the ``gr.Image`` produced by ``return_habitat_image``.
    """
    normalized_plot = gbif_normalization(text)
    # Keep only the top-ranked prediction's label.
    label = classification_model(normalized_plot)[0]["label"]
    message = f"This vegetation plot belongs to the habitat {label}."
    return message, return_habitat_image(label)
def _first_unused_prediction(predictions, present):
    """Return the first prediction whose token is not already in ``present``, or None."""
    return next((pred for pred in predictions if pred["token_str"] not in present), None)


def masking(text):
    """Find the species most likely missing from a vegetation plot.

    The normalised species list is tried with a ``[MASK]`` inserted at every
    position: before the first species, between each pair, and after the
    last. At each position the highest-ranked candidate not already in the
    plot is kept. Overall, the candidate with the highest score wins; ties
    keep the earliest position.

    Args:
        text: Comma-separated species names typed by the user.

    Returns:
        A ``(message, image)`` pair. The message names the species and its
        0-based insertion index in the normalised list. ``image`` is None
        when no new species could be proposed.
    """
    text = gbif_normalization(text)
    plot = text.split(", ") if text else []
    present = set(plot)  # O(1) membership tests

    max_score = 0
    best_prediction = None
    best_position = None
    best_sentence = None
    for position in range(len(plot) + 1):
        masked_text = ", ".join(plot[:position] + ["[MASK]"] + plot[position:])
        # One forward pass per position. The fill-mask pipeline returns its
        # top_k candidates ranked by score, so the first unused one is the best.
        prediction = _first_unused_prediction(mask_model(masked_text), present)
        if prediction is None:
            # Every candidate is already in the plot; nothing to propose here.
            continue
        if prediction["score"] > max_score:
            max_score = prediction["score"]
            best_prediction = prediction["token_str"]
            best_position = position
            best_sentence = prediction["sequence"]

    if best_prediction is None:
        return "No missing species could be proposed for this vegetation plot.", None

    message = f"The most likely missing species is {best_prediction} (position {best_position}).\nThe new vegetation plot is {best_sentence}."
    return message, return_species_image(best_prediction)
# Two-tab Gradio UI: habitat classification and missing-species finding.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                text_button = gr.Button("Classify")
            with gr.Column():
                text_output_1 = gr.Textbox()  # predicted-habitat sentence
                text_output_2 = gr.Image()  # habitat picture
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["sparganium erectum, calystegia sepium, persicaria amphibia"]],
            inputs=[species],
            outputs=[text_output_1, text_output_2],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        # Same input/output column layout as the classification tab.
        with gr.Row():
            with gr.Column():
                species_2 = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                image_button = gr.Button("Find")
            with gr.Column():
                image_output_1 = gr.Textbox()  # missing-species sentence
                image_output_2 = gr.Image()  # species picture
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            inputs=[species_2],
            outputs=[image_output_1, image_output_2],
            fn=masking,
            cache_examples=True,
        )

    # Event listeners must be registered inside the Blocks context.
    text_button.click(classification, inputs=[species], outputs=[text_output_1, text_output_2])
    image_button.click(masking, inputs=[species_2], outputs=[image_output_1, image_output_2])

demo.launch()