# Spaces: Running
import gradio as gr | |
from transformers import pipeline | |
from datasets import load_dataset | |
import requests | |
from bs4 import BeautifulSoup | |
# Pipelines already built by return_model, keyed by (pipeline task, model name).
_PIPELINE_CACHE = {}


def return_model(task):
    """Return the Hugging Face pipeline that serves ``task``.

    ``'classification'`` selects the text-classification pipeline; any other
    value selects the fill-mask pipeline. Each pipeline is built once and
    reused on later calls instead of being re-created for every request.

    Args:
        task: Name of the requested task.

    Returns:
        The ``transformers`` pipeline object for the task.
    """
    if task == 'classification':
        key = ("text-classification", "CesarLeblanc/test_model")
    else:
        key = ("fill-mask", "CesarLeblanc/fill_mask_model")

    if key not in _PIPELINE_CACHE:
        pipeline_task, model_name = key
        _PIPELINE_CACHE[key] = pipeline(pipeline_task, model=model_name)
    return _PIPELINE_CACHE[key]
# Holds the dataset once return_dataset has loaded it (key: "dataset").
_DATASET_CACHE = {}


def return_dataset():
    """Return the text-classification dataset, loading it on first use only.

    Returns:
        The object returned by ``load_dataset`` for
        ``CesarLeblanc/text_classification_dataset``. The same object is
        handed back on every later call.
    """
    if "dataset" not in _DATASET_CACHE:
        _DATASET_CACHE["dataset"] = load_dataset("CesarLeblanc/text_classification_dataset")
    return _DATASET_CACHE["dataset"]
def return_text(habitat_label, habitat_score, confidence):
    """Build the sentence that reports the classification result.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Score of the prediction; it is multiplied by 100
            before being compared with ``confidence``.
        confidence: Required confidence, expressed as a percentage.

    Returns:
        A sentence naming the habitat and its probability when the score
        reaches ``confidence``; otherwise a sentence saying that no habitat
        can be assigned at that confidence level.
    """
    score_percent = habitat_score * 100
    # The fallback message promises "at least {confidence}%", so a score that
    # sits exactly on the threshold is accepted. This also agrees with
    # return_habitat_image, which only hides the image when the score is
    # strictly below the threshold.
    if score_percent >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {score_percent:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return a Gradio image illustrating the predicted habitat.

    The FloraVeg overview page of the habitat is fetched and the first
    ``<img>`` whose ``src`` starts with the syntaxa thumbnail prefix is used.
    A placeholder image is returned when the score is below ``confidence``,
    when the page cannot be fetched, or when no such image is on the page.

    Args:
        habitat_label: Name of the predicted habitat, used in the page URL.
        habitat_score: Score of the prediction; it is multiplied by 100
            before being compared with ``confidence``.
        confidence: Required confidence, expressed as a percentage.

    Returns:
        A ``gr.Image`` whose value is the chosen image URL.
    """
    placeholder_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"

    # A prediction below the threshold always shows the placeholder, so there
    # is no point in fetching the page in that case.
    if habitat_score * 100 < confidence:
        return gr.Image(value=placeholder_url)

    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # Bounded wait so an unresponsive site cannot hang the request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network failures get the same fallback as a non-200 answer.
        return gr.Image(value=placeholder_url)

    image_url = placeholder_url
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)
def return_species_image(species):
    """Return a Gradio image illustrating ``species``.

    The first character of the name is capitalised, the matching FloraVeg
    taxon overview page is fetched, and the first ``<img>`` whose ``src``
    starts with the large taxa image prefix is used. A placeholder image is
    returned when the page cannot be fetched or holds no such image.

    Args:
        species: Species name, used in the page URL.

    Returns:
        A ``gr.Image`` whose value is the chosen image URL.
    """
    placeholder_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"

    # Slicing instead of indexing keeps an empty name from raising IndexError.
    species = species[:1].capitalize() + species[1:]
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    try:
        # Bounded wait so an unresponsive site cannot hang the request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network failures get the same fallback as a non-200 answer.
        return gr.Image(value=placeholder_url)

    image_url = placeholder_url
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)
def classification(text, typology, confidence, task):
    """Classify a vegetation plot and describe the predicted habitat.

    Args:
        text: Species list passed to the model.
        typology: Accepted for interface compatibility; not used here.
        confidence: Required confidence, expressed as a percentage.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(message, image)`` pair built by ``return_text`` and
        ``return_habitat_image`` from the first prediction of the model.
    """
    classifier = return_model(task)
    label_names = return_dataset()['train'].features['label'].names
    top_prediction = classifier(text)[0]

    # The model label is split on "_" and its second part is used as an
    # index into the dataset's label names.
    label_index = int(top_prediction['label'].split('_')[1])
    habitat_label = label_names[label_index]
    habitat_score = top_prediction['score']

    message = return_text(habitat_label, habitat_score, confidence)
    picture = return_habitat_image(habitat_label, habitat_score, confidence)
    return message, picture
def masking(text, task):
    """Extend a species list with one species predicted by the model.

    Two ``[MASK]`` tokens are appended to ``text`` and the best candidate
    for each of them is joined with a space to form the new species name.

    Args:
        text: Species list passed to the model.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(text, image)`` pair: the input text followed by the predicted
        species, and the image returned by ``return_species_image`` for it.
    """
    fill_mask = return_model(task)
    predictions = fill_mask(text + ', [MASK] [MASK]', top_k=1)
    # Each entry of ``predictions`` is indexed with [0] to take its first
    # candidate (only one is requested through top_k=1).
    new_species = ' '.join(candidates[0]['token_str'] for candidates in predictions)
    completed_text = f"{text}, {new_species}"
    species_image = return_species_image(new_species)
    return completed_text, species_image
def plantbert(text, typology, confidence, task):
    """Run the requested task and return its text and image outputs.

    Args:
        text: Species list entered by the user.
        typology: Typology forwarded to ``classification``.
        confidence: Required confidence, forwarded to ``classification``.
        task: ``"classification"`` runs the classifier; any other value
            runs the masking task.

    Returns:
        The ``(text, image)`` pair produced by the selected task.
    """
    if task != "classification":
        return masking(text, task)
    return classification(text, typology, confidence, task)
# Texts shown at the top of the interface.
title = "Pl@ntBERT"
description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"

# Input widgets, in the order of plantbert's parameters:
# text, typology, confidence, task.
inputs = [
    gr.Textbox(
        lines=2,
        label="Species",
        placeholder="Enter a list of comma-separated binomial names here.",
    ),
    gr.Dropdown(
        ["EUNIS"],
        value="EUNIS",
        label="Typology",
        info="Will add more typologies later!",
    ),
    gr.Slider(
        0,
        100,
        value=90,
        label="Confidence",
        info="Choose the level of confidence for the prediction.",
    ),
    gr.Radio(
        ["classification", "masking"],
        value="classification",
        label="Task",
        info="Which task to choose?",
    ),
]

# Output widgets: the text result followed by the image.
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]

# One ready-made example per task.
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", "EUNIS", 90, "masking"],
]

io = gr.Interface(
    fn=plantbert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)
io.launch()