File size: 4,082 Bytes
a5316e5
 
7e0319c
d563836
 
a5316e5
6176ef8
 
 
 
 
443a3b3
a5316e5
6176ef8
 
 
 
 
 
 
 
 
 
 
 
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc3586c
 
 
 
6176ef8
 
 
 
 
 
d563836
54a8fdb
1c2f25a
a5316e5
6176ef8
 
 
16ad661
 
6176ef8
 
 
 
 
 
 
 
 
 
 
a7e54a7
6176ef8
 
a7e54a7
a5316e5
6176ef8
 
 
 
 
a5316e5
 
7e0319c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import functools

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline

@functools.lru_cache(maxsize=None)
def _load_pipeline(pipeline_task, model_id):
    """Build a transformers pipeline once per (task, model) pair and reuse it."""
    return pipeline(pipeline_task, model=model_id)


def return_model(task):
    """Return the Hugging Face pipeline used for *task*.

    ``"classification"`` selects the text-classification model; any other value
    selects the fill-mask model. Pipelines are memoized, so model weights are
    loaded only on the first request instead of on every UI submission.

    Args:
        task: Task name chosen in the UI (``"classification"`` or ``"masking"``).

    Returns:
        A ``transformers`` pipeline object (shared between calls).
    """
    if task == 'classification':
        return _load_pipeline("text-classification", "CesarLeblanc/test_model")
    return _load_pipeline("fill-mask", "CesarLeblanc/fill_mask_model")

@functools.lru_cache(maxsize=1)
def return_dataset():
    """Return the habitat classification dataset, loading it only once.

    The dataset is used to map the classifier's numeric label index to a
    habitat name. Memoizing avoids re-running ``load_dataset`` on every
    classification request. Callers must treat the returned object as
    read-only because it is shared between calls.

    Returns:
        The object returned by ``datasets.load_dataset`` for
        ``"CesarLeblanc/text_classification_dataset"``.
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")

def return_text(habitat_label, habitat_score, confidence):
    """Format the classification result shown to the user.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Model probability for that habitat (a fraction; it is
            multiplied by 100 to express it in percent).
        confidence: Minimum required confidence, in percent.

    Returns:
        A sentence naming the habitat and its probability when the score
        reaches the threshold, otherwise a sentence saying no habitat could
        be assigned.
    """
    score_percent = habitat_score * 100
    # ">=" so a score exactly at the threshold is accepted: the rejection
    # message promises "at least {confidence}%", and return_image only swaps
    # in the placeholder when the score is strictly below the threshold.
    if score_percent >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {score_percent:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
    
_IMAGE_NOT_FOUND_URL = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
_FLORAVEG_THUMB_PREFIX = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"


def _fetch_habitat_thumbnail(habitat_label, timeout):
    """Scrape the FloraVeg overview page of *habitat_label* for a thumbnail URL.

    Returns the ``src`` of the first ``<img>`` whose URL starts with the
    FloraVeg thumbnail prefix, or None if the page could not be fetched,
    did not answer with HTTP 200, or has no such image.
    """
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # DNS failures, refused connections and timeouts must not crash the
        # UI callback; the caller falls back to the placeholder image.
        return None
    if response.status_code != 200:
        return None
    soup = BeautifulSoup(response.text, 'html.parser')
    img_tag = soup.find('img', src=lambda x: x and x.startswith(_FLORAVEG_THUMB_PREFIX))
    return img_tag['src'] if img_tag is not None else None


def return_image(habitat_label, habitat_score, confidence, timeout=10):
    """Build the image component illustrating the predicted habitat.

    The FloraVeg page is only scraped when the prediction reaches the
    confidence threshold. Otherwise, or when no thumbnail can be found, a
    generic "image not found" picture is shown.

    Args:
        habitat_label: Name of the predicted habitat (used in the FloraVeg URL).
        habitat_score: Model probability for that habitat (a fraction).
        confidence: Minimum required confidence, in percent.
        timeout: Seconds to wait for FloraVeg before giving up.

    Returns:
        A ``gr.Image`` pointing at the chosen image URL.
    """
    image_url = _IMAGE_NOT_FOUND_URL
    if habitat_score * 100 >= confidence:
        image_url = _fetch_habitat_thumbnail(habitat_label, timeout) or _IMAGE_NOT_FOUND_URL
    return gr.Image(value=image_url)

def classification(text, typology, confidence, task):
    """Classify a vegetation plot and build the text and image outputs.

    Args:
        text: Comma-separated species list entered by the user.
        typology: Selected habitat typology (not used by the current logic).
        confidence: Minimum required confidence, in percent.
        task: Task name used to pick the pipeline (``"classification"``).

    Returns:
        ``(message, image)`` as produced by ``return_text`` and ``return_image``.
    """
    classifier = return_model(task)
    label_names = return_dataset()['train'].features['label'].names
    top_prediction = classifier(text)[0]
    # The classifier's label is expected to look like "<prefix>_<index>"
    # (e.g. "LABEL_3"); the numeric part indexes the dataset's label names.
    label_index = int(top_prediction['label'].split('_')[1])
    habitat_label = label_names[label_index]
    habitat_score = top_prediction['score']
    message = return_text(habitat_label, habitat_score, confidence)
    picture = return_image(habitat_label, habitat_score, confidence)
    return message, picture

def masking(text, task):
    """Extend the species list with the fill-mask model's top prediction.

    Two mask tokens are appended to the user's comma-separated species list.
    Each one is filled with the model's top-1 token, and the two tokens are
    joined with a space (presumably genus + epithet of one species).

    Args:
        text: Comma-separated species names entered by the user (may be empty).
        task: Task name forwarded to ``return_model`` to select the fill-mask model.

    Returns:
        ``(text, image)``: the extended species list and a placeholder ``gr.Image``.
    """
    model = return_model(task)
    # Ask the tokenizer for its mask token instead of hard-coding "[MASK]" so
    # the prompt stays valid if the model is swapped for one using another token.
    mask = model.tokenizer.mask_token
    species = text.strip()
    # Avoid a dangling leading ", " when the user submits an empty list.
    prefix = f"{species}, " if species else ""
    predictions = model(f"{prefix}{mask} {mask}", top_k=1)
    # With two masks the pipeline yields one candidate list per mask; keep the
    # single (top-1) candidate of each.
    new_species = " ".join(candidates[0]['token_str'] for candidates in predictions)
    image = gr.Image(value="https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png")
    return prefix + new_species, image

def plantbert(text, typology, confidence, task):
    """Gradio entry point: route the request to the selected task.

    ``"classification"`` runs the habitat classifier; any other task value
    runs the species-completion (fill-mask) task.

    Returns:
        A ``(text_output, image_output)`` pair for the two Gradio outputs.
    """
    handler_result = (
        classification(text, typology, confidence, task)
        if task == "classification"
        else masking(text, task)
    )
    message, picture = handler_result
    return message, picture

# Input components, in the same order as plantbert's parameters:
# (text, typology, confidence, task).
inputs = [
    gr.Textbox(
        lines=2,
        label="Species",
        placeholder="Enter a list of comma-separated binomial names here.",
    ),
    gr.Dropdown(
        ["EUNIS"],
        value="EUNIS",
        label="Typology",
        info="Will add more typologies later!",
    ),
    gr.Slider(
        0,
        100,
        value=90,
        label="Confidence",
        info="Choose the level of confidence for the prediction.",
    ),
    gr.Radio(
        ["classification", "masking"],
        value="classification",
        label="Task",
        info="Which task to choose?",
    ),
]

# Output components: the result text and an image ("image" is Gradio's
# string shortcut for an image output component).
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]

title = "Pl@ntBERT"

description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"

# Each example row follows the input order: species, typology, confidence, task.
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["thinopyrum junceum, cakile maritima", "EUNIS", 90, "masking"],
]

io = gr.Interface(
    fn=plantbert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)

io.launch()