File size: 6,556 Bytes
a5316e5
 
d563836
 
544f914
a5316e5
5f4434d
8da738a
7145ecb
544f914
 
 
6176ef8
f37b5da
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
3fcba4f
6176ef8
 
 
ccf126e
bb9c09b
ccf126e
 
 
 
 
 
 
 
 
 
 
1e09a50
ccf126e
 
 
b1a0d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82104c
b1a0d53
6f59e3c
aa09a05
8da738a
 
 
 
544f914
 
8da738a
3fcba4f
6176ef8
b8df8bd
b1a0d53
aa09a05
24390e2
 
 
 
f2c857b
24390e2
aa09a05
 
 
 
f80db47
aa09a05
f80db47
aa09a05
 
f80db47
 
aa09a05
24390e2
f2c857b
aa09a05
24390e2
 
 
 
f2c857b
24390e2
3fcba4f
3c63477
6176ef8
 
5282aca
7145ecb
1e09a50
f30d0ea
5282aca
1e09a50
b8df8bd
53e71ae
142304a
 
53e71ae
 
 
142304a
d3c40d6
142304a
53e71ae
5282aca
1e09a50
5282aca
142304a
3cb6c3b
 
 
142304a
d3c40d6
142304a
a5316e5
142304a
 
a5316e5
7145ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import pandas as pd

# Hugging Face pipelines, built once at import time from the local model folders.
# The classifier is configured for its 5 best labels, the fill-mask model for
# its 100 best candidate tokens.
classification_model = pipeline(
    "text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
mask_model = pipeline(
    "fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)

# Spreadsheet of EUNIS habitats, read once at import time.
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')
    
def return_habitat_image(habitat_label, use_placeholder=True):
    """Return a ``gr.Image`` illustrating the habitat ``habitat_label``.

    Args:
        habitat_label: Habitat code used to build the FloraVeg overview URL.
        use_placeholder: When true (the default), the generic Pl@ntNet
            picture is returned and FloraVeg is not contacted at all. This
            mirrors the current behaviour, where the scraped picture is
            always discarded because the image rights have not been granted
            yet. Pass ``False`` to actually scrape the FloraVeg thumbnail.

    Returns:
        A ``gr.Image`` whose value is an image URL: the placeholder, the
        scraped thumbnail, or a "not found" picture when scraping fails.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    thumbnail_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"

    # While we don't have the rights to display FloraVeg pictures, skip the
    # (blocking) HTTP request entirely: its result would be thrown away.
    if use_placeholder:
        return gr.Image(value=placeholder_url)

    image_url = not_found_url
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # A timeout keeps a slow or unreachable site from hanging the UI callback.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda src: src and src.startswith(thumbnail_prefix))
        if img_tag:
            image_url = img_tag['src']

    return gr.Image(value=image_url)

def return_species_image(species, use_placeholder=True):
    """Return a ``gr.Image`` illustrating the species ``species``.

    Args:
        species: Species name; it is capitalized before being inserted in
            the FloraVeg taxon overview URL.
        use_placeholder: When true (the default), the generic Pl@ntNet
            picture is returned and FloraVeg is not contacted at all. This
            mirrors the current behaviour, where the scraped picture is
            always replaced by the placeholder. Pass ``False`` to actually
            scrape the FloraVeg picture.

    Returns:
        A ``gr.Image`` whose value is an image URL: the placeholder, the
        scraped picture, or a "not found" picture when scraping fails.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    picture_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"

    # The scraped picture is currently never shown, so avoid the (blocking)
    # HTTP request whose result would be thrown away.
    if use_placeholder:
        return gr.Image(value=placeholder_url)

    image_url = not_found_url
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"
    try:
        # A timeout keeps a slow or unreachable site from hanging the UI callback.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda src: src and src.startswith(picture_prefix))
        if img_tag:
            image_url = img_tag['src']

    return gr.Image(value=image_url)

def gbif_normalization(text):
    """Normalize a comma-separated list of species names with the GBIF API.

    Each non-blank name is sent to the GBIF ``species/match`` endpoint. When
    the response contains a ``species`` field, that value replaces the name;
    otherwise (no match, network failure, unparsable response) the name is
    kept as typed, so normalization is best-effort and never aborts the
    request being served.

    Args:
        text: Species names separated by commas. Surrounding whitespace is
            stripped and blank entries (e.g. from a trailing comma) are
            dropped.

    Returns:
        The normalized names joined with ``", "`` and lower-cased. An input
        with no names yields an empty string.
    """
    match_url = "https://api.gbif.org/v1/species/match"

    normalized = []
    for raw_name in text.split(','):
        name = raw_name.strip()
        if not name:
            # Nothing to look up; also avoids a pointless API round-trip.
            continue
        try:
            # ``params`` lets requests URL-encode the name (spaces, '&', ...).
            response = requests.get(match_url, params={"name": name}, timeout=10)
            payload = response.json()
        except (requests.RequestException, ValueError):
            # Network problem or non-JSON body: fall back to the raw name.
            payload = {}
        if isinstance(payload, dict) and 'species' in payload:
            normalized.append(payload['species'])
        else:
            normalized.append(name)

    return ", ".join(normalized).lower()

def classification(text, k):
    """Predict the most likely habitats of a vegetation plot.

    Args:
        text: Comma-separated species names describing the plot; they are
            normalized with ``gbif_normalization`` before classification.
        k: Number of top habitats to report. It is truncated to an integer
            and raised to at least 1, because it is used as a slice bound
            and a UI slider may deliver a float.

    Returns:
        A ``(text, image)`` tuple: a sentence listing the predicted habitat
        codes plus the name of the most likely one, and the image component
        returned by ``return_habitat_image`` for that habitat.
    """
    text = gbif_normalization(text)
    result = classification_model(text)

    top_k = max(1, int(k))
    habitat_labels = [res['label'] for res in result[0][:top_k]]
    best_label = habitat_labels[0]

    # Choose the wording from the labels actually returned, not from k: the
    # model may return fewer labels than requested.
    if len(habitat_labels) == 1:
        text = f"This vegetation plot belongs to the habitat {best_label}."
    else:
        text = f"This vegetation plot belongs to the habitats {', '.join(habitat_labels[:-1])} and {habitat_labels[-1]}."

    # Look up the habitat name; fall back to the bare code when the table
    # has no row for it instead of raising IndexError.
    matching_names = eunis_habitats.loc[
        eunis_habitats['EUNIS 2020 code'] == best_label,
        'EUNIS-2021 habitat name',
    ].values
    habitat_name = matching_names[0] if len(matching_names) else best_label

    text += f"\nThe most likely habitat is {habitat_name} (see image below)."
    image_output = return_habitat_image(best_label)
    return text, image_output

def masking(text):
    """Suggest the species most likely missing from a vegetation plot.

    A ``[MASK]`` token is inserted at every possible position of the
    (GBIF-normalized) species list. For each position, the best-scoring
    candidate of the fill-mask model that is not already in the plot is
    kept, and the overall best-scoring one is reported.

    Args:
        text: Comma-separated species names describing the plot.

    Returns:
        A ``(text, image)`` tuple: a sentence naming the suggested species,
        its position and the completed plot, together with the image
        component from ``return_species_image``. If every candidate at
        every position is already in the plot, an explanatory message and
        ``None`` are returned instead.
    """
    text = gbif_normalization(text)
    text_split = text.split(', ')
    present_species = set(text_split)

    max_score = 0
    best_prediction = None
    best_position = None
    best_sentence = None

    for i in range(len(text_split) + 1):
        masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])

        # Run the model once per position; its candidates are then scanned
        # in the order returned for the first species not already present.
        candidates = mask_model(masked_text)
        prediction = next(
            (cand for cand in candidates if cand['token_str'] not in present_species),
            None,
        )
        if prediction is None:
            # Every candidate is already in the plot at this position.
            continue

        if prediction['score'] > max_score:
            max_score = prediction['score']
            best_prediction = prediction['token_str']
            best_position = i
            best_sentence = prediction['sequence']

    if best_prediction is None:
        return "No missing species could be predicted for this vegetation plot.", None

    text = f"The most likely missing species is {best_prediction} (position {best_position}).\nThe new vegetation plot is {best_sentence}."
    image = return_species_image(best_prediction)
    return text, image

# Two-tab Gradio interface: habitat classification and missing-species search.
with gr.Blocks() as demo:

    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                # step=1 keeps Top-k a whole number: classification() uses it
                # as a slice bound.
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of top habitats to display.")
            with gr.Column():
                classification_text = gr.Textbox()
                classification_image = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["sparganium erectum, calystegia sepium, persicaria amphibia", 1]],
            inputs=[species_classification, k_classification],
            outputs=[classification_text, classification_image],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                masking_text = gr.Textbox()
                masking_image = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            inputs=[species_masking],
            outputs=[masking_text, masking_image],
            fn=masking,
            cache_examples=True,
        )

    # Wire each button to its callback; both callbacks return (text, image).
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[classification_text, classification_image])
    button_masking.click(masking, inputs=[species_masking], outputs=[masking_text, masking_image])

demo.launch()