# Scraped from a Hugging Face Space page (page status: "Runtime error").
# File size: 4,011 Bytes
# Blame commits: a5316e5 7e0319c d563836 a5316e5 6176ef8 a5316e5 6176ef8 d563836 c8ce48e d563836 6176ef8 d563836 6176ef8 1c2f25a a5316e5 6176ef8 16ad661 6176ef8 a7e54a7 6176ef8 a7e54a7 a5316e5 6176ef8 a5316e5 7e0319c
from functools import lru_cache

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline
@lru_cache(maxsize=None)
def _load_pipeline(pipeline_task, model_name):
    """Build a transformers pipeline once and reuse it for later requests.

    Only called with the two fixed (task, model) pairs in ``return_model``,
    so the unbounded cache holds at most two pipelines.
    """
    return pipeline(pipeline_task, model=model_name)


def return_model(task):
    """Return the Hugging Face pipeline that serves *task*.

    Args:
        task: "classification" selects the text-classification model; any
            other value selects the fill-mask model.

    Returns:
        The (cached, shared) pipeline object; callers invoke it on text.
    """
    if task == 'classification':
        return _load_pipeline("text-classification", "CesarLeblanc/test_model")
    return _load_pipeline("fill-mask", "CesarLeblanc/fill_mask_model")
@lru_cache(maxsize=1)
def return_dataset():
    """Load the habitat classification dataset once and reuse it.

    The same object is returned on every call, so callers must treat it as
    read-only.

    Returns:
        The result of ``load_dataset("CesarLeblanc/text_classification_dataset")``.
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")
def return_text(habitat_label, habitat_score, confidence):
    """Format the classification verdict shown to the user.

    Args:
        habitat_label: predicted habitat name.
        habitat_score: model probability in [0, 1].
        confidence: minimum required probability, in percent (0-100).

    Returns:
        A sentence naming the habitat when the score reaches the threshold,
        otherwise a sentence saying no habitat can be assigned.
    """
    percent = habitat_score * 100
    # ">=" matches the "at least {confidence}%" wording and agrees with
    # return_image, which only hides the image when percent < confidence.
    if percent >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
# Placeholder displayed whenever no habitat thumbnail can be shown.
_IMAGE_NOT_FOUND_URL = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
# URL prefix identifying the thumbnail <img> on a FloraVeg habitat page.
_FLORAVEG_THUMB_PREFIX = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"


def return_image(habitat_label, habitat_score, confidence, timeout=10):
    """Build the image output for a habitat prediction.

    Scrapes the FloraVeg overview page for *habitat_label* and uses the
    first thumbnail whose ``src`` starts with the known prefix.

    Args:
        habitat_label: predicted habitat name, inserted into the page URL.
        habitat_score: model probability in [0, 1].
        confidence: minimum required probability, in percent (0-100).
        timeout: seconds to wait for FloraVeg before giving up.

    Returns:
        A ``gr.Image`` showing the thumbnail. The placeholder is used instead
        when the score is below *confidence*, the request fails, the page is
        not HTTP 200, or no matching thumbnail exists.
    """
    # Below the threshold the thumbnail would be discarded anyway, so skip
    # the network round-trip entirely.
    if habitat_score * 100 < confidence:
        return gr.Image(value=_IMAGE_NOT_FOUND_URL)

    image_url = _IMAGE_NOT_FOUND_URL
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # Network problems degrade to the placeholder instead of crashing the UI.
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find(
            'img',
            src=lambda src: bool(src) and src.startswith(_FLORAVEG_THUMB_PREFIX),
        )
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)
def classification(text, typology, confidence, task):
    """Classify a species list into a habitat and build both UI outputs.

    Args:
        text: comma-separated species names entered by the user.
        typology: selected typology (accepted for the Gradio interface; not
            used by this function).
        confidence: minimum required probability, in percent (0-100).
        task: task name forwarded to ``return_model``.

    Returns:
        A ``(message, image)`` pair built by ``return_text`` and ``return_image``.
    """
    classifier = return_model(task)
    label_names = return_dataset()['train'].features['label'].names
    best = classifier(text)[0]
    # The raw label is split on "_" and its second field is used as the class
    # index — presumably the default "LABEL_<n>" naming; confirm against the model config.
    class_index = int(best['label'].split('_')[1])
    habitat_label = label_names[class_index]
    habitat_score = best['score']
    message = return_text(habitat_label, habitat_score, confidence)
    picture = return_image(habitat_label, habitat_score, confidence)
    return message, picture
def masking(text, task):
    """Extend a species list by letting the fill-mask model predict two tokens.

    Two mask tokens are appended to *text* after ", ". Each one is replaced
    by the model's top-1 prediction for that position.

    Args:
        text: comma-separated species names entered by the user.
        task: task name forwarded to ``return_model`` (any value other than
            "classification" selects the fill-mask model).

    Returns:
        A ``(completed_text, image)`` pair; the image is always the placeholder.
    """
    model = return_model(task)
    # Use the tokenizer's own mask token rather than assuming BERT's "[MASK]".
    mask_token = model.tokenizer.mask_token
    prompt = f"{text}, {mask_token} {mask_token}"
    predictions = model(prompt, top_k=1)
    # The fill-mask pipeline returns a flat candidate list for a single mask,
    # but one candidate list per mask when the input contains several.
    if predictions and isinstance(predictions[0], dict):
        predictions = [predictions]
    completed = prompt
    for candidates in predictions:
        # Per-mask results come in input order, so fill the first remaining mask.
        # NOTE(review): sub-word pieces (e.g. BERT "##" tokens) are inserted
        # verbatim — confirm the model's vocabulary uses whole words.
        completed = completed.replace(mask_token, candidates[0]["token_str"].strip(), 1)
    image = gr.Image(value="https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png")
    return completed, image
def plantbert(text, typology, confidence, task):
    """Gradio entry point: route the request to classification or masking.

    Args:
        text: comma-separated species names entered by the user.
        typology: selected habitat typology.
        confidence: minimum required probability, in percent (0-100).
        task: "classification" or "masking" (any non-"classification" value
            is treated as masking).

    Returns:
        A ``(text_output, image_output)`` pair for the two Gradio outputs.
    """
    if task == "classification":
        return classification(text, typology, confidence, task)
    # masking only takes (text, task); passing all four arguments raised TypeError.
    return masking(text, task)
# Gradio input widgets, in the order of plantbert's parameters
# (text, typology, confidence, task).
inputs = [
    gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here."),
    gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!"),
    gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction."),
    gr.Radio(["classification", "masking"], value="classification", label="Task", info="Which task to choose?"),
]
# Matches plantbert's (text_output, image_output) return pair.
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]
title = "Pl@ntBERT"
description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["thinopyrum junceum, cakile maritima", "EUNIS", 90, "masking"],
]
io = gr.Interface(fn=plantbert,
                  inputs=inputs,
                  outputs=outputs,
                  title=title,
                  description=description,
                  examples=examples)
io.launch()