File size: 5,199 Bytes
a5316e5
 
7e0319c
d563836
 
a5316e5
6176ef8
 
 
 
 
443a3b3
a5316e5
6176ef8
 
 
 
 
 
 
 
 
 
 
ccf126e
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
 
 
f5a3585
6176ef8
 
 
ccf126e
 
 
 
 
 
 
 
 
 
 
 
 
f5a3585
ccf126e
 
 
b8df8bd
 
6176ef8
 
 
 
 
 
ccf126e
6176ef8
 
b8df8bd
 
fc3586c
 
 
ccf126e
f5a3585
ccf126e
6176ef8
 
5282aca
f30d0ea
 
5282aca
f30d0ea
b8df8bd
53e71ae
 
 
 
 
 
 
5282aca
53e71ae
5282aca
f30d0ea
5282aca
3cb6c3b
 
 
 
5282aca
a5316e5
ca9ec8b
 
a5316e5
5282aca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import functools

import gradio as gr
import requests
from bs4 import BeautifulSoup
from datasets import load_dataset
from transformers import pipeline

@functools.lru_cache(maxsize=None)
def _load_pipeline(pipeline_task, model_name):
    """Build a Hugging Face pipeline once per (task, model) pair and reuse it.

    The key space is bounded: return_model only ever passes two combinations.
    """
    return pipeline(pipeline_task, model=model_name)

def return_model(task):
    """Return the inference pipeline for the requested task.

    Args:
        task: ``'classification'`` selects the text-classification model;
            any other value selects the fill-mask model.

    Returns:
        A transformers pipeline. Pipelines are cached, so repeated calls
        return the same object instead of reloading the weights on every
        request.
    """
    if task == 'classification':
        return _load_pipeline("text-classification", "CesarLeblanc/test_model")
    return _load_pipeline("fill-mask", "CesarLeblanc/fill_mask_model")

@functools.lru_cache(maxsize=1)
def return_dataset():
    """Load the habitat classification dataset, caching it after the first call.

    Callers in this file only read label metadata from it
    (``dataset['train'].features['label'].names``). Caching avoids re-running
    ``load_dataset`` on every request.

    Returns:
        The object returned by ``load_dataset``, indexed by split name
        (e.g. ``'train'``).
    """
    return load_dataset("CesarLeblanc/text_classification_dataset")

def return_text(habitat_label, habitat_score, confidence):
    """Build the user-facing sentence describing a classification result.

    Args:
        habitat_label: Name of the predicted habitat.
        habitat_score: Model score for that habitat. It is multiplied by 100
            to obtain a percentage, so presumably it lies in [0, 1].
        confidence: Minimum percentage (0-100, from the UI slider) the score
            must reach for the habitat to be reported.

    Returns:
        The sentence to display.
    """
    percent = habitat_score * 100
    # ">=" matches the "at least" wording of the rejection message and the
    # "< confidence" rejection test used in return_habitat_image.
    if percent >= confidence:
        return f"This vegetation plot belongs to the habitat {habitat_label} with the probability {percent:.2f}%."
    return f"We can't assign an habitat to this vegetation plot with a confidence of at least {confidence}%."
    
def return_habitat_image(habitat_label, habitat_score, confidence):
    """Return a Gradio image component for the predicted habitat.

    Scrapes the FloraVeg habitat overview page for a thumbnail. It falls back
    to a "not found" picture when:

    - the page cannot be fetched,
    - the page has no matching image, or
    - the score is below the confidence threshold.

    Args:
        habitat_label: Habitat name, inserted into the FloraVeg URL.
        habitat_score: Model score; multiplied by 100 to compare with
            ``confidence``.
        confidence: Minimum percentage required to show the habitat image.

    Returns:
        A ``gr.Image`` component.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    image_url = not_found_url
    try:
        # Bounded timeout so an unresponsive server cannot hang the handler;
        # network errors degrade to the fallback image instead of crashing
        # the whole classification request.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"))
        if img_tag:
            image_url = img_tag['src']
    if habitat_score*100 < confidence:
        image_url = not_found_url
    # NOTE(review): this unconditionally replaces the URL chosen above, so the
    # scraped image and the confidence check never reach the UI. Kept to
    # preserve current output; confirm whether the override is still wanted.
    image_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQoQPZxckzsiQyFh9w7Z7aJ38d23lvLQFj4QemMFjw2lvc18iQrzDYf7EzzmD7cFdfbsZU&usqp=CAU"
    return gr.Image(value=image_url)

def return_species_image(species):
    """Return a Gradio image component for a plant species.

    Scrapes the FloraVeg taxon overview page for a large image. It falls back
    to a "not found" picture when the page cannot be fetched or has no
    matching image.

    Args:
        species: Species name, presumably a binomial such as
            "quercus robur". Its first character is upper-cased before it is
            inserted into the URL.

    Returns:
        A ``gr.Image`` component.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    # Capitalize only the first letter; slicing (instead of species[0]) keeps
    # an empty string from raising IndexError.
    species = species[:1].capitalize() + species[1:]
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    image_url = not_found_url
    try:
        # Bounded timeout so an unresponsive server cannot hang the handler;
        # network errors degrade to the fallback image instead of crashing.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"))
        if img_tag:
            image_url = img_tag['src']
    # NOTE(review): this unconditionally replaces the URL chosen above, so the
    # scraped image never reaches the UI. Kept to preserve current output;
    # confirm whether the override is still wanted.
    image_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQoQPZxckzsiQyFh9w7Z7aJ38d23lvLQFj4QemMFjw2lvc18iQrzDYf7EzzmD7cFdfbsZU&usqp=CAU"
    return gr.Image(value=image_url)

def classification(text, typology, confidence):
    """Classify a vegetation plot and describe the predicted habitat.

    Args:
        text: Comma-separated species names entered by the user.
        typology: Habitat typology selected in the UI. Currently unused;
            the dropdown only offers "EUNIS".
        confidence: Minimum percentage required to report the habitat.

    Returns:
        A ``(sentence, image_component)`` pair for the two output widgets.
    """
    classifier = return_model("classification")
    dataset = return_dataset()
    best = classifier(text)[0]
    # The pipeline label is expected to look like "<prefix>_<index>"; the
    # index is mapped back to the human-readable class name of the dataset.
    class_index = int(best['label'].split('_')[1])
    label_names = dataset['train'].features['label'].names
    habitat_label = label_names[class_index]
    habitat_score = best['score']
    sentence = return_text(habitat_label, habitat_score, confidence)
    picture = return_habitat_image(habitat_label, habitat_score, confidence)
    return sentence, picture

def masking(text):
    """Predict the species missing from a vegetation plot.

    Two ``[MASK]`` tokens are appended to the species list, presumably one
    for the genus and one for the epithet of a binomial name. The top
    prediction for each mask is joined into a species name.

    Args:
        text: Comma-separated species names entered by the user.

    Returns:
        A ``(sentence, image_component)`` pair for the two output widgets.
    """
    model = return_model("masking")
    masked_text = text + ', [MASK] [MASK]'
    pred = model(masked_text, top_k=1)
    # With several masks the pipeline yields one candidate list per mask;
    # top_k=1 means each list holds just the best candidate.
    new_species = ' '.join(candidates[0]['token_str'] for candidates in pred)
    message = f"The last species from this vegetation plot is probably {new_species}."
    image = return_species_image(new_species)
    return message, image

# Gradio UI with two tabs: habitat classification and missing-species finding.
with gr.Blocks() as demo:
    gr.Markdown("""# Pl@ntBERT""")
    
    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""Classify vegetation plots!""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                typology = gr.Dropdown(["EUNIS"], value="EUNIS", label="Typology", info="Will add more typologies later!")
                confidence = gr.Slider(0, 100, value=90, label="Confidence", info="Choose the level of confidence for the prediction.")
            with gr.Column():
                # Outputs of `classification`: the sentence and the habitat image.
                text_output_1 = gr.Textbox()
                text_output_2 = gr.Image()
        text_button = gr.Button("Classify")
        
    with gr.Tab("Missing species finding"):
        gr.Markdown("""Find the missing species!""")
        with gr.Row():
            species_2 = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                # Outputs of `masking`: the sentence and the species image.
                image_output_1 = gr.Textbox()
                image_output_2 = gr.Image()
        image_button = gr.Button("Find")

    text_button.click(classification, inputs=[species, typology, confidence], outputs=[text_output_1, text_output_2])
    image_button.click(masking, inputs=[species_2], outputs=[image_output_1, image_output_2])

# Launch only when run as a script. Importing the module (e.g. Gradio's reload
# mode, which looks up `demo` itself) must not start a blocking server.
if __name__ == "__main__":
    demo.launch()