Spaces:
Running
Running
import gradio as gr | |
from transformers import pipeline | |
from datasets import load_dataset | |
import requests | |
from bs4 import BeautifulSoup | |
# Building a transformers pipeline loads the model weights, which is far too
# slow to repeat on every Gradio request; keep one instance per model.
_PIPELINE_CACHE = {}


def return_model(task):
    """Return the Hugging Face pipeline for *task*, building it at most once.

    Args:
        task: ``'classification'`` selects the text-classification model
            ``CesarLeblanc/test_model``; any other value selects the
            fill-mask model ``CesarLeblanc/fill_mask_model``.

    Returns:
        A ``transformers`` pipeline. The same object is returned for every
        call that resolves to the same model.
    """
    if task == 'classification':
        key = ("text-classification", "CesarLeblanc/test_model")
    else:
        key = ("fill-mask", "CesarLeblanc/fill_mask_model")
    model = _PIPELINE_CACHE.get(key)
    if model is None:
        # NOTE(review): not locked; two concurrent first calls could each
        # build a pipeline. That wastes one load but both results work.
        pipeline_task, model_name = key
        model = pipeline(pipeline_task, model=model_name)
        _PIPELINE_CACHE[key] = model
    return model
# The dataset is only needed for its label names; load it once per process
# instead of on every classification request.
_DATASET_CACHE = {}


def return_dataset():
    """Return the ``CesarLeblanc/text_classification_dataset`` dataset.

    The dataset is loaded on the first call and the same object is
    returned afterwards. Callers should treat it as read-only, since it is
    shared between requests.
    """
    if "dataset" not in _DATASET_CACHE:
        _DATASET_CACHE["dataset"] = load_dataset("CesarLeblanc/text_classification_dataset")
    return _DATASET_CACHE["dataset"]
def return_text(habitat_label, habitat_score, confidence):
    """Build the user-facing sentence describing a classification result.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Score of the prediction; it is multiplied by 100 and
            shown as a percentage, so presumably a probability in [0, 1].
        confidence: Minimum score required, as a percentage (0-100).

    Returns:
        The habitat sentence when the score reaches *confidence*, otherwise
        a message saying no habitat could be assigned at that confidence.
    """
    score_percent = habitat_score * 100
    # ">=": a score equal to the threshold satisfies "at least {confidence}%",
    # which is exactly what the refusal message below promises.
    if score_percent >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {score_percent:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return a ``gr.Image`` illustrating the predicted habitat.

    The FloraVeg.EU overview page for *habitat_label* is fetched, and the
    first ``<img>`` whose ``src`` lies under the syntaxa thumbnail folder is
    used. A "not found" placeholder is returned instead in these cases:

    * the score is below *confidence*;
    * the page cannot be fetched or does not return HTTP 200;
    * the page contains no such image.

    Args:
        habitat_label: Habitat code inserted into the FloraVeg URL.
        habitat_score: Prediction score; compared as ``score * 100``
            against *confidence*.
        confidence: Minimum score required, as a percentage (0-100).

    Returns:
        ``gr.Image`` whose value is the selected image URL.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumb_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"
    # NOTE(review): the previous version ended by unconditionally replacing
    # the result with a fixed PlantNet banner, which made the lookup below
    # dead code. Reinstate such an override deliberately if a static image
    # is really wanted.

    # An untrusted prediction is never illustrated, so skip the request.
    if habitat_score * 100 < confidence:
        return gr.Image(value=not_found_url)

    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # Bounded wait: an unresponsive server must not hang the handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        return gr.Image(value=not_found_url)
    if response.status_code != 200:
        return gr.Image(value=not_found_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(thumb_prefix))
    image_url = img_tag['src'] if img_tag else not_found_url
    return gr.Image(value=image_url)
def return_species_image(species):
    """Return a ``gr.Image`` showing a photo of *species*.

    The FloraVeg.EU taxon page is fetched, and the first ``<img>`` whose
    ``src`` lies under the large taxa image folder is used. A "not found"
    placeholder is returned in these cases:

    * the name is empty;
    * the page cannot be fetched or does not return HTTP 200;
    * the page contains no such image.

    Args:
        species: Species name, e.g. ``"vaccinium myrtillus"``. Its first
            character is capitalised before building the URL.

    Returns:
        ``gr.Image`` whose value is the selected image URL.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    photo_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"
    # NOTE(review): the previous version ended by unconditionally replacing
    # the result with a fixed PlantNet banner, which made the lookup below
    # dead code. Reinstate such an override deliberately if a static image
    # is really wanted.
    if not species:
        return gr.Image(value=not_found_url)

    # Capitalise only the genus initial, presumably the form FloraVeg taxon
    # URLs expect. Slicing avoids the IndexError species[0] would raise.
    species = species[:1].capitalize() + species[1:]
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    try:
        # Bounded wait: an unresponsive server must not hang the handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        return gr.Image(value=not_found_url)
    if response.status_code != 200:
        return gr.Image(value=not_found_url)

    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda src: src and src.startswith(photo_prefix))
    image_url = img_tag['src'] if img_tag else not_found_url
    return gr.Image(value=image_url)
def classification(text, typology, confidence):
    """Classify a vegetation plot and illustrate the predicted habitat.

    Args:
        text: Comma-separated species names entered in the UI.
        typology: Habitat typology chosen in the UI. It is not used yet;
            the dropdown currently offers only "EUNIS".
        confidence: Minimum score, in percent, needed to report a habitat.

    Returns:
        ``(message, image)`` for the two output components of the tab.
    """
    classifier = return_model("classification")
    habitat_dataset = return_dataset()
    best = classifier(text)[0]
    # The model reports labels shaped "<prefix>_<index>" (presumably the
    # default "LABEL_<n>"); map the index back to the dataset's class name.
    class_names = habitat_dataset['train'].features['label'].names
    predicted_label = class_names[int(best['label'].split('_')[1])]
    predicted_score = best['score']
    message = return_text(predicted_label, predicted_score, confidence)
    picture = return_habitat_image(predicted_label, predicted_score, confidence)
    return message, picture
def masking(text):
    """Predict the species most likely missing from a vegetation plot.

    Two mask tokens are appended to the comma-separated species list
    (presumably one per word of a binomial name). The fill-mask model's
    top candidate for each mask is joined into the predicted species.

    Args:
        text: Comma-separated species names entered in the UI.

    Returns:
        ``(message, image)``: a sentence naming the predicted species and
        the ``gr.Image`` built for it by ``return_species_image``.
    """
    model = return_model("masking")
    # Ask the tokenizer for its mask token instead of hard-coding "[MASK]",
    # so the prompt stays valid if the model is swapped for one whose
    # tokenizer uses a different token (e.g. "<mask>").
    mask = model.tokenizer.mask_token
    # Drop trailing whitespace and commas so input such as "a, b," does not
    # put an empty list entry in front of the masks.
    plot = text.strip().rstrip(",").rstrip()
    masked_text = f"{plot}, {mask} {mask}"
    predictions = model(masked_text, top_k=1)
    # With several masks the pipeline returns one candidate list per mask;
    # top_k=1 leaves a single candidate in each list.
    new_species = ' '.join(candidates[0]['token_str'] for candidates in predictions)
    message = f"The last species from this vegetation plot is probably {new_species}."
    image = return_species_image(new_species)
    return message, image
# Two-tab Gradio interface:
#   * "Vegetation plot classification" calls classification();
#   * "Missing species finding" calls masking().
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                typology = gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!")
                confidence = gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction.")
            with gr.Column():
                text_output_1 = gr.Textbox()  # habitat sentence
                text_output_2 = gr.Image()    # habitat picture
        text_button = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        # gr.Examples needs the components it fills (`inputs`) and one inner
        # list per example holding a value for each of those components.
        gr.Examples(
            [["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90]],
            inputs=[species, typology, confidence],
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            species_2 = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                image_output_1 = gr.Textbox()  # predicted-species sentence
                image_output_2 = gr.Image()    # species picture
        image_button = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            [["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            inputs=[species_2],
        )

    # Event listeners are attached inside the Blocks context.
    text_button.click(classification, inputs=[species, typology, confidence], outputs=[text_output_1, text_output_2])
    image_button.click(masking, inputs=[species_2], outputs=[image_output_1, image_output_2])

demo.launch()