Spaces:
Running
Running
File size: 4,855 Bytes
a5316e5 7e0319c d563836 a5316e5 6176ef8 443a3b3 a5316e5 6176ef8 ccf126e d563836 c8ce48e d563836 6176ef8 ccf126e 6176ef8 ccf126e 6176ef8 fc3586c ccf126e 6176ef8 d563836 54a8fdb 1c2f25a a5316e5 6176ef8 16ad661 6176ef8 a7e54a7 6176ef8 c180063 a7e54a7 a5316e5 6176ef8 a5316e5 7e0319c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from functools import lru_cache

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline
@lru_cache(maxsize=None)
def _load_pipeline(pipeline_task, model_name):
    """Build a transformers pipeline once per (pipeline_task, model_name) pair."""
    return pipeline(pipeline_task, model=model_name)


def return_model(task):
    """Return the pipeline matching the requested task.

    Args:
        task: 'classification' selects the text-classification pipeline;
            any other value selects the fill-mask pipeline.

    Returns:
        The pipeline object built by ``transformers.pipeline``. The same
        object is returned on repeated calls, so the model is not rebuilt
        on every request.
    """
    if task == 'classification':
        return _load_pipeline("text-classification", "CesarLeblanc/test_model")
    # Every non-classification task shares the single fill-mask pipeline.
    return _load_pipeline("fill-mask", "CesarLeblanc/fill_mask_model")
@lru_cache(maxsize=None)
def return_dataset():
    """Return the text-classification dataset, loading it only once.

    Returns:
        The object produced by
        ``load_dataset("CesarLeblanc/text_classification_dataset")``. The
        result is memoized, so later calls reuse the first one instead of
        calling ``load_dataset`` again.
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")
def return_text(habitat_label, habitat_score, confidence):
    """Build the sentence describing the classification outcome.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Model score for that habitat, as a fraction
            (it is multiplied by 100 before being compared and displayed).
        confidence: Minimum accepted confidence, as a percentage.

    Returns:
        A sentence naming the habitat and its probability when the score
        reaches the threshold, otherwise a sentence saying that no habitat
        could be assigned at that confidence level.
    """
    # ">=" rather than ">": the rejection message promises "at least
    # {confidence}%", so a score exactly on the threshold must be accepted.
    if habitat_score*100 >= confidence:
        text = f"This vegetation plot belongs to the habitat {habitat_label} with the probability {habitat_score*100:.2f}%."
    else:
        text = f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
    return text
def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return a Gradio image illustrating the predicted habitat.

    The FloraVeg overview page of the habitat is fetched and the first
    ``<img>`` whose ``src`` starts with the syntaxa thumbnail prefix is used.
    A placeholder image is used when the score is below the threshold, when
    the page cannot be fetched, or when no matching image is found.

    Args:
        habitat_label: Habitat name, inserted into the FloraVeg URL.
        habitat_score: Model score for that habitat, as a fraction.
        confidence: Minimum accepted confidence, as a percentage.

    Returns:
        A ``gr.Image`` whose value is the chosen image URL.
    """
    placeholder_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumbs_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"

    # Below the threshold the placeholder is always shown, so there is no
    # point in paying for the network round-trip.
    if habitat_score*100 < confidence:
        return gr.Image(value=placeholder_url)

    image_url = placeholder_url
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # A timeout keeps an unresponsive server from hanging the request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network failures degrade to the placeholder, like a non-200 reply.
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith(thumbs_prefix))
        if img_tag:
            image_url = img_tag['src']

    return gr.Image(value=image_url)
def return_species_image(species):
    """Return a Gradio image illustrating the given species.

    The first character of the name is capitalized, the FloraVeg taxon
    overview page for that name is fetched, and the first ``<img>`` whose
    ``src`` starts with the large taxa image prefix is used. A placeholder
    image is used when the name is empty, when the page cannot be fetched,
    or when no matching image is found.

    Args:
        species: Species name, inserted into the FloraVeg URL.

    Returns:
        A ``gr.Image`` whose value is the chosen image URL.
    """
    placeholder_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    taxa_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"

    # An empty name has no first character to capitalize and no page to query.
    if not species:
        return gr.Image(value=placeholder_url)

    species = species[0].capitalize() + species[1:]
    image_url = placeholder_url
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    try:
        # A timeout keeps an unresponsive server from hanging the request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network failures degrade to the placeholder, like a non-200 reply.
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith(taxa_prefix))
        if img_tag:
            image_url = img_tag['src']

    return gr.Image(value=image_url)
def classification(text, typology, confidence, task):
    """Classify a species list and describe the predicted habitat.

    Args:
        text: Species list passed to the classification pipeline.
        typology: Accepted for interface compatibility; not used here.
        confidence: Minimum accepted confidence, as a percentage.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(message, image)`` pair produced by ``return_text`` and
        ``return_habitat_image`` for the top prediction.
    """
    classifier = return_model(task)
    dataset = return_dataset()
    top_prediction = classifier(text)[0]

    # The second "_"-separated piece of the raw label is used as an integer
    # index into the dataset's label names.
    raw_label = top_prediction['label']
    label_names = dataset['train'].features['label'].names
    habitat_label = label_names[int(raw_label.split('_')[1])]
    habitat_score = top_prediction['score']

    message = return_text(habitat_label, habitat_score, confidence)
    picture = return_habitat_image(habitat_label, habitat_score, confidence)
    return message, picture
def masking(text, task):
    """Predict one extra species for a species list and illustrate it.

    Two ``[MASK]`` tokens are appended to the list and the top prediction
    for each is joined with a space to form the new species name.

    Args:
        text: Comma-separated species list.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(text, image)`` pair: the species list extended with the
        predicted species, and the image from ``return_species_image``.
    """
    model = return_model(task)

    # Only insert the ", " separator when there is something to separate,
    # so empty input does not yield a dangling leading comma.
    has_species = bool(text.strip())
    masked_text = text + ', [MASK] [MASK]' if has_species else '[MASK] [MASK]'

    pred = model(masked_text, top_k=1)
    # pred is treated as one candidate list per mask; with top_k=1 the first
    # entry of each list is the prediction that is kept.
    new_species = ' '.join(candidates[0]['token_str'] for candidates in pred)

    completed_text = text + ', ' + new_species if has_species else new_species
    image = return_species_image(new_species)
    return completed_text, image
def plantbert(text, typology, confidence, task):
    """Dispatch a request to the classification or the masking workflow.

    Args:
        text: Species list entered by the user.
        typology: Typology choice, forwarded to ``classification`` only.
        confidence: Confidence threshold, forwarded to ``classification`` only.
        task: "classification" runs ``classification``; any other value
            runs ``masking``.

    Returns:
        The ``(text, image)`` pair produced by the selected workflow.
    """
    run_classification = task == "classification"
    message, picture = (
        classification(text, typology, confidence, task)
        if run_classification
        else masking(text, task)
    )
    return message, picture
# Input widgets, in the order of plantbert's parameters:
# text, typology, confidence, task.
inputs = [
    gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here."),
    gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!"),
    gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction."),
    gr.Radio(["classification", "masking"], value="classification", label="Task", info="Which task to choose?"),
]

# Output widgets, matching the (text, image) pair returned by plantbert.
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]

title = "Pl@ntBERT"
description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"

# One example per task; each row supplies a value for every input widget.
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", "EUNIS", 90, "masking"],
]

io = gr.Interface(
    fn=plantbert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)

# The stray trailing "|" that followed this call was not valid Python.
io.launch()