File size: 4,105 Bytes
a5316e5
 
7e0319c
d563836
 
a5316e5
6176ef8
 
 
 
 
443a3b3
a5316e5
6176ef8
 
 
 
 
 
 
 
 
 
 
 
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc3586c
 
 
 
6176ef8
 
 
 
 
 
d563836
54a8fdb
1c2f25a
a5316e5
6176ef8
 
 
16ad661
 
6176ef8
 
 
 
 
 
 
 
 
 
 
a7e54a7
6176ef8
c180063
a7e54a7
a5316e5
6176ef8
 
 
 
 
a5316e5
 
7e0319c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from functools import lru_cache

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline

@lru_cache(maxsize=None)
def _load_pipeline(kind):
    """Build the transformers pipeline for *kind* ("classification" or "fill-mask").

    Cached so the model weights are loaded once per process instead of on
    every request; *kind* only ever takes two values, so the cache is bounded.
    """
    if kind == "classification":
        return pipeline("text-classification", model="CesarLeblanc/test_model")
    return pipeline("fill-mask", model="CesarLeblanc/fill_mask_model")

def return_model(task):
    """Return the (cached) pipeline that serves *task*.

    Args:
        task: "classification" selects the text-classification model; any
            other value selects the fill-mask model.

    Returns:
        A transformers pipeline object, shared across calls.
    """
    # Normalize before hitting the cache so every non-classification task
    # string reuses the single fill-mask pipeline.
    kind = "classification" if task == "classification" else "fill-mask"
    return _load_pipeline(kind)

@lru_cache(maxsize=1)
def return_dataset():
    """Load and return the text-classification dataset, once per process.

    The app reads the label names from this dataset on every classification
    request. Caching avoids re-resolving and re-loading the dataset each
    time. Callers only read from the returned object.
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")

def return_text(habitat_label, habitat_score, confidence):
    """Build the user-facing sentence describing the classification result.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Model score, presumably a probability in [0, 1]; it is
            scaled by 100 and shown as a percentage.
        confidence: Minimum required confidence, in percent.

    Returns:
        The habitat assignment message when the scaled score is at least
        *confidence*, otherwise a message saying no habitat can be assigned.
    """
    percentage = habitat_score * 100
    # ">=" matches the "at least" wording below. It also matches the
    # image logic, which only falls back to a placeholder when
    # percentage < confidence.
    if percentage >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {percentage:.2f}%."
    return f"We can't assign a habitat to this vegetation plot with a confidence of at least {confidence}%."
    
def return_image(habitat_label, habitat_score, confidence):
    """Return a ``gr.Image`` illustrating the predicted habitat.

    The picture is the first syntaxon thumbnail found on the habitat's
    FloraVeg.EU overview page. A placeholder "image not found" picture is
    used instead when:

    * the scaled score is below *confidence*,
    * the page cannot be fetched (network error or non-200 status), or
    * the page contains no matching thumbnail.

    Args:
        habitat_label: Habitat code inserted into the FloraVeg.EU URL.
        habitat_score: Model score, presumably in [0, 1]; compared as
            ``habitat_score * 100`` against *confidence*.
        confidence: Minimum required confidence, in percent.

    Returns:
        A ``gr.Image`` component pointing at the selected image URL.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumb_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"

    # Below the threshold the placeholder is shown whatever the page holds,
    # so skip the network round-trip entirely.
    if habitat_score * 100 < confidence:
        return gr.Image(value=not_found_url)

    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # Without a timeout, requests can block forever on an unresponsive
        # server and hang the Gradio request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Treat network failures like a missing image instead of crashing.
        return gr.Image(value=not_found_url)

    image_url = not_found_url
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda src: src and src.startswith(thumb_prefix))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)

def classification(text, typology, confidence, task):
    """Classify a vegetation plot and build the text and image outputs.

    Args:
        text: Species list describing the plot, passed straight to the model.
        typology: Selected habitat typology; not used by this function.
        confidence: Minimum confidence, in percent, required to report a habitat.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(message, image)`` pair built by ``return_text`` and
        ``return_image``.
    """
    classifier = return_model(task)
    data = return_dataset()
    top = classifier(text)[0]
    # The model label has the form "<prefix>_<index>" (presumably "LABEL_<n>").
    # The numeric part indexes the dataset's label names.
    label_index = int(top['label'].split('_')[1])
    habitat = data['train'].features['label'].names[label_index]
    score = top['score']
    message = return_text(habitat, score, confidence)
    picture = return_image(habitat, score, confidence)
    return message, picture

def masking(text, task):
    """Extend a species list with the fill-mask model's top prediction.

    Two ``[MASK]`` tokens are appended after *text*. The best candidate for
    each mask is joined with spaces and appended to the original text.

    Args:
        text: Comma-separated species names.
        task: Task name forwarded to ``return_model``.

    Returns:
        A ``(completed_text, image)`` pair. The image is always the
        "image not found" placeholder.
    """
    unmasker = return_model(task)
    predictions = unmasker(text + ', [MASK] [MASK]', top_k=1)
    # With several masks the pipeline yields one candidate list per mask.
    # top_k=1 leaves a single best candidate at index 0 of each list.
    completion = ' '.join(candidates[0]['token_str'] for candidates in predictions)
    placeholder = gr.Image(value="https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png")
    return text + ', ' + completion, placeholder

def plantbert(text, typology, confidence, task):
    """Gradio callback: route the request to the handler for *task*.

    "classification" goes to ``classification``; any other task goes to
    ``masking``, which ignores *typology* and *confidence*.

    Returns:
        The ``(text_output, image_output)`` pair from the chosen handler.
    """
    if task != "classification":
        text_output, image_output = masking(text, task)
    else:
        text_output, image_output = classification(text, typology, confidence, task)
    return text_output, image_output

# Input widgets, in the positional order expected by plantbert():
# (text, typology, confidence, task).
inputs = [
    gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here."),
    gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!"),
    gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction."),
    gr.Radio(["classification", "masking"], value="classification", label="Task", info="Which task to choose?"),
]

# Output components, matching plantbert()'s (text, image) return pair.
outputs = [
    gr.Textbox(lines=2, label="Vegetation Plot Classification Result"),
    "image",
]

title = "Pl@ntBERT"

description = "Vegetation Plot Classification: enter the species found in a vegetation plot and see its EUNIS habitat!"

# Each example row supplies one value per input widget above.
examples = [
    ["sparganium erectum, calystegia sepium, persicaria amphibia", "EUNIS", 90, "classification"],
    ["vaccinium myrtillus, dryopteris dilatata, molinia caerulea", "EUNIS", 90, "masking"],
]

io = gr.Interface(
    fn=plantbert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
)

# Start the server only when executed as a script (`python app.py`), so that
# importing this module does not block on launching the web UI.
if __name__ == "__main__":
    io.launch()