File size: 7,039 Bytes
a5316e5
 
d563836
 
20742e4
a5316e5
5f4434d
edd2cf0
3fcba4f
6176ef8
f37b5da
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
3fcba4f
6176ef8
 
 
ccf126e
bb9c09b
ccf126e
 
 
 
 
 
 
 
 
 
 
1e09a50
ccf126e
 
 
b1a0d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37b5da
b1a0d53
6f59e3c
6176ef8
3fcba4f
f37b5da
3fcba4f
6176ef8
b8df8bd
b1a0d53
24390e2
 
 
 
f2c857b
24390e2
 
 
f80db47
 
 
 
 
 
 
 
24390e2
f2c857b
24390e2
 
 
 
 
f2c857b
24390e2
 
 
 
f80db47
 
 
 
 
 
 
 
24390e2
f2c857b
24390e2
 
 
 
 
 
f2c857b
24390e2
 
 
f80db47
29188eb
 
 
 
 
 
 
24390e2
f2c857b
24390e2
 
 
 
 
f2c857b
24390e2
3fcba4f
3c63477
6176ef8
 
5282aca
1e09a50
f30d0ea
5282aca
1e09a50
b8df8bd
53e71ae
 
5f4434d
53e71ae
 
 
5282aca
d3c40d6
f37b5da
53e71ae
5282aca
1e09a50
5282aca
3cb6c3b
 
 
 
5282aca
d3c40d6
6b0aebb
a5316e5
7f326dc
ca9ec8b
a5316e5
5282aca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import random

# Initialize models
# Both pipelines are built once at import time and shared by every request.
# The same local directory name is used for the model weights and tokenizer.
classification_model = pipeline(
    task="text-classification",
    model="plantbert_text_classification_model",
    tokenizer="plantbert_text_classification_model",
)
# top_k=100 asks the fill-mask pipeline for up to 100 candidates per mask, so
# masking() can skip candidates that are already present in the plot.
mask_model = pipeline(
    task="fill-mask",
    model="plantbert_fill_mask_model",
    tokenizer="plantbert_fill_mask_model",
    top_k=100,
)
    
def return_habitat_image(habitat_label, use_floraveg=False, timeout=10):
    """Return a ``gr.Image`` illustrating the habitat *habitat_label*.

    FloraVeg.EU thumbnails cannot be displayed yet (no usage rights), so by
    default a fixed Pl@ntNet placeholder is returned and no network request is
    made. Scraping can be re-enabled with ``use_floraveg=True``.

    Args:
        habitat_label: Habitat label predicted by the classifier; inserted into
            the FloraVeg.EU habitat overview URL.
        use_floraveg: When True, try to scrape the habitat thumbnail from
            FloraVeg.EU, falling back to an "image not found" picture.
        timeout: Seconds to wait for FloraVeg.EU before giving up.

    Returns:
        A ``gr.Image`` component whose value is the chosen image URL.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"  # While we don't have the rights
    if not use_floraveg:
        # The scraped image would be discarded anyway, so skip the request
        # (and its latency / failure modes) entirely.
        return gr.Image(value=placeholder_url)

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # Best-effort illustration: a network failure keeps the fallback image.
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)

def return_species_image(species, use_floraveg=False, timeout=10):
    """Return a ``gr.Image`` illustrating the species *species*.

    FloraVeg.EU photos cannot be displayed yet (no usage rights), so by
    default a fixed Pl@ntNet placeholder is returned and no network request is
    made. Scraping can be re-enabled with ``use_floraveg=True``.

    Args:
        species: Species name, e.g. "vaccinium myrtillus"; capitalized before
            being inserted into the FloraVeg.EU taxon overview URL.
        use_floraveg: When True, try to scrape the species photo from
            FloraVeg.EU, falling back to an "image not found" picture.
        timeout: Seconds to wait for FloraVeg.EU before giving up.

    Returns:
        A ``gr.Image`` component whose value is the chosen image URL.
    """
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    if not use_floraveg:
        # The scraped image would be discarded anyway, so skip the request
        # (and its latency / failure modes) entirely.
        return gr.Image(value=placeholder_url)

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"
    try:
        response = requests.get(floraveg_url, timeout=timeout)
    except requests.RequestException:
        # Best-effort illustration: a network failure keeps the fallback image.
        response = None
    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)

def gbif_normalization(text, timeout=10):
    """Normalize a comma-separated list of species names with GBIF.

    Each stripped name is looked up with the GBIF ``species/match`` endpoint.
    When the JSON response has a ``species`` field, that name replaces the
    input; otherwise the input name is kept. Lookup failures (network error,
    non-2xx status, unparsable JSON) also keep the input name, so a GBIF outage
    degrades normalization instead of crashing the app.

    Args:
        text: Comma-separated species names, e.g.
            "sparganium erectum, calystegia sepium".
        timeout: Seconds to wait for each GBIF request.

    Returns:
        The (possibly normalized) names joined with ", " and lower-cased.
    """
    match_url = "https://api.gbif.org/v1/species/match"
    species_gbif = []
    for species in (part.strip() for part in text.split(',')):
        if not species:
            # Empty fragment (e.g. trailing comma): nothing to look up, keep the slot.
            species_gbif.append(species)
            continue
        try:
            # params= lets requests URL-encode the name (spaces, '&', accents, ...).
            response = requests.get(match_url, params={"name": species}, timeout=timeout)
            response.raise_for_status()
            match = response.json()
        except (requests.RequestException, ValueError):
            match = {}
        if isinstance(match, dict) and 'species' in match:
            species_gbif.append(match["species"])
        else:
            species_gbif.append(species)
    return ", ".join(species_gbif).lower()

def classification(text):
    """Predict the habitat of a vegetation plot.

    Args:
        text: Comma-separated species names entered by the user; normalized
            with ``gbif_normalization`` before classification.

    Returns:
        A tuple ``(message, image)``: a sentence naming the top predicted
        habitat label, and the component built by ``return_habitat_image``.
    """
    normalized_plot = gbif_normalization(text)
    top_label = classification_model(normalized_plot)[0]['label']
    message = f"This vegetation plot belongs to the habitat {top_label}."
    return message, return_habitat_image(top_label)

def _first_unlisted_prediction(masked_text, listed):
    """Return the first fill-mask prediction whose token is not in *listed*, or None."""
    # Run the model ONCE per masked sentence and walk its ranked candidates,
    # instead of re-running the model for every skipped candidate.
    for prediction in mask_model(masked_text):
        if prediction['token_str'] not in listed:
            return prediction
    # Every returned candidate is already in the plot.
    return None


def masking(text):
    """Find the species most likely missing from a vegetation plot.

    A ``[MASK]`` token is inserted at every position of the normalized species
    list: before the first species, between each pair, and after the last. For
    each position, the first predicted species not already in the plot is
    kept. The overall highest-scoring one wins; on ties the earliest position
    wins.

    Args:
        text: Comma-separated species names entered by the user; normalized
            with ``gbif_normalization`` first.

    Returns:
        A tuple ``(message, image)``. ``message`` names the predicted species,
        its insertion position (0 = before the first species, len(plot) = after
        the last) and the completed plot. ``image`` comes from
        ``return_species_image``, or is None when no candidate was found.
    """
    text = gbif_normalization(text)
    plot = text.split(', ')
    listed = set(plot)  # built once for O(1) membership tests

    max_score = 0
    best_prediction = None
    best_position = None
    best_sentence = None

    # Positions 0..len(plot) cover the first, middle and last insertion points.
    for position in range(len(plot) + 1):
        masked_text = ', '.join(plot[:position] + ['[MASK]'] + plot[position:])
        prediction = _first_unlisted_prediction(masked_text, listed)
        if prediction is None:
            continue
        # Strict '>' keeps the earliest position on ties.
        if prediction['score'] > max_score:
            max_score = prediction['score']
            best_prediction = prediction['token_str']
            best_position = position
            best_sentence = prediction['sequence']

    if best_prediction is None:
        return "No missing species could be predicted for this vegetation plot.", None

    text = f"The most likely missing species is {best_prediction} (position {best_position}).\nThe new vegetation plot is {best_sentence}."
    image = return_species_image(best_prediction)
    return text, image

# Two-tab Gradio interface: habitat classification and missing-species finding.
with gr.Blocks() as demo:
    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    # Tab 1: predict the habitat of a plot from its species list.
    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                # TODO: add a top-k control here once classification() can
                # return several habitats (it currently takes only the text).
            with gr.Column():
                text_output_1 = gr.Textbox()
                text_output_2 = gr.Image()
        text_button = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        # Keyword arguments keep this robust to positional-order changes in gr.Examples.
        gr.Examples(
            examples=[["sparganium erectum, calystegia sepium, persicaria amphibia"]],
            inputs=[species],
            outputs=[text_output_1, text_output_2],
            fn=classification,
            cache_examples=True,
        )

    # Tab 2: predict which species is most likely missing from the plot.
    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            species_2 = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                missing_text_output = gr.Textbox()
                missing_image_output = gr.Image()
        image_button = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            inputs=[species_2],
            outputs=[missing_text_output, missing_image_output],
            fn=masking,
            cache_examples=True,
        )

    text_button.click(classification, inputs=[species], outputs=[text_output_1, text_output_2])
    image_button.click(masking, inputs=[species_2], outputs=[missing_text_output, missing_image_output])

demo.launch()