File size: 6,253 Bytes
a5316e5
 
d563836
 
a5316e5
5f4434d
8da738a
7145ecb
6176ef8
f37b5da
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
3fcba4f
6176ef8
 
 
ccf126e
bb9c09b
ccf126e
 
 
 
 
 
 
 
 
 
 
1e09a50
ccf126e
 
 
b1a0d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82104c
b1a0d53
6f59e3c
aa09a05
8da738a
 
 
 
 
3fcba4f
6176ef8
b8df8bd
b1a0d53
aa09a05
24390e2
 
 
 
f2c857b
24390e2
aa09a05
 
 
 
 
 
f80db47
aa09a05
f80db47
aa09a05
 
f80db47
 
aa09a05
24390e2
f2c857b
aa09a05
24390e2
 
 
 
 
f2c857b
24390e2
3fcba4f
3c63477
6176ef8
 
5282aca
7145ecb
1e09a50
f30d0ea
5282aca
1e09a50
b8df8bd
53e71ae
 
7145ecb
53e71ae
 
 
5282aca
d3c40d6
7145ecb
53e71ae
5282aca
1e09a50
5282aca
3cb6c3b
 
 
 
5282aca
d3c40d6
6b0aebb
a5316e5
7f326dc
ca9ec8b
a5316e5
7145ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup

# Initialize models
# Both pipelines are built once, at import time, and shared by every request
# handler below. Model and tokenizer are read from the same local directory.
classification_model = pipeline(
    "text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
mask_model = pipeline(
    "fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)
    
def return_habitat_image(habitat_label):
    """Build the image component displayed next to a habitat prediction.

    The FloraVeg.EU overview page of ``habitat_label`` is fetched and searched
    for a thumbnail, falling back to a generic "image not found" picture when
    the page or the thumbnail is unavailable. The scraped URL is then
    deliberately overridden by a placeholder picture, because the rights to
    display the FloraVeg images have not been obtained yet.

    Args:
        habitat_label: Habitat label predicted by the classifier; it is
            inserted as-is into the FloraVeg URL.

    Returns:
        A ``gr.Image`` whose value is the URL of the picture to display.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    thumbnail_prefix = "https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"

    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    image_url = not_found_url
    try:
        # A timeout keeps a slow or unreachable site from hanging the UI.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network problems must not crash the handler: keep the fallback.
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith(thumbnail_prefix))
        if img_tag:
            image_url = img_tag['src']

    # While we don't have the rights, always show the placeholder instead of
    # the scraped picture. Remove this override to display the real image.
    image_url = placeholder_url
    return gr.Image(value=image_url)

def return_species_image(species):
    """Build the image component displayed next to a predicted species.

    The FloraVeg.EU taxon page of ``species`` is fetched and searched for a
    large picture, falling back to a generic "image not found" picture when
    the page or the picture is unavailable. The scraped URL is then
    deliberately overridden by a placeholder picture (same override as for
    habitat images), so the placeholder is what is currently displayed.

    Args:
        species: Species name; it is capitalized (first letter upper-case,
            the rest lower-case) before being inserted into the FloraVeg URL.

    Returns:
        A ``gr.Image`` whose value is the URL of the picture to display.
    """
    not_found_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    placeholder_url = "https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg"
    picture_prefix = "https://files.ibot.cas.cz/cevs/images/taxa/large/"

    species = species.capitalize()
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species}"
    image_url = not_found_url
    try:
        # A timeout keeps a slow or unreachable site from hanging the UI.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Network problems must not crash the handler: keep the fallback.
        response = None

    if response is not None and response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith(picture_prefix))
        if img_tag:
            image_url = img_tag['src']

    # Always show the placeholder instead of the scraped picture for now.
    # Remove this override to display the real image.
    image_url = placeholder_url
    return gr.Image(value=image_url)

def gbif_normalization(text):
    """Normalize a comma-separated list of species names with the GBIF API.

    Each non-empty name is sent to the GBIF ``species/match`` endpoint. When
    the response contains a ``species`` field, that value replaces the name;
    otherwise (no match, HTTP error, network failure, invalid JSON) the name
    is kept as typed. Empty entries, e.g. from a trailing comma, are dropped
    without querying GBIF.

    Args:
        text: Comma-separated species names.

    Returns:
        The normalized names joined by ``", "`` and lower-cased.
    """
    match_url = "https://api.gbif.org/v1/species/match"
    names = [name.strip() for name in text.split(',')]

    species_gbif = []
    for name in names:
        if not name:
            continue
        try:
            # Passing the name through ``params`` lets requests URL-encode it,
            # so characters such as '&', '#' or '=' cannot corrupt the query.
            response = requests.get(match_url, params={"name": name}, timeout=10)
            response.raise_for_status()
            match = response.json()
        except (requests.RequestException, ValueError):
            # Best effort: fall back to the name as typed.
            match = None
        if isinstance(match, dict) and 'species' in match:
            species_gbif.append(match['species'])
        else:
            species_gbif.append(name)

    return ", ".join(species_gbif).lower()

def classification(text, k=1):
    """Predict the habitat(s) of a vegetation plot.

    The species list is first normalized with ``gbif_normalization`` and then
    passed to the text-classification pipeline; the first ``k`` labels of its
    output are reported.

    Args:
        text: Comma-separated species names describing the vegetation plot.
        k: Number of habitats to report. Defaults to 1 so the function can be
            wired to a UI event that only supplies the species text. Floats
            (as a slider may deliver) are truncated, and values below 1 are
            raised to 1.

    Returns:
        A ``(sentence, image)`` tuple: the sentence lists the predicted
        habitat(s), and the image is the one returned by
        ``return_habitat_image`` for the first habitat.
    """
    text = gbif_normalization(text)
    # Slicing requires a positive integer, whatever numeric type the UI sends.
    k = max(1, int(k))
    result = classification_model(text)
    habitat_labels = [res['label'] for res in result[0][:k]]
    # Choose the wording from the labels actually available, which may be
    # fewer than the k requested.
    if len(habitat_labels) == 1:
        text = f"This vegetation plot belongs to the habitat {habitat_labels[0]}."
    else:
        text = f"This vegetation plot belongs to the habitats {', '.join(habitat_labels[:-1])} and {habitat_labels[-1]}."
    image_output = return_habitat_image(habitat_labels[0])
    return text, image_output

def masking(text):
    """Suggest the species most likely missing from a vegetation plot.

    The species list is normalized with ``gbif_normalization``. A ``[MASK]``
    token is then inserted at every possible position of the list (including
    both ends), and the fill-mask pipeline is queried for each masked list.
    For each position, the first candidate that is not already in the plot is
    kept; the candidate with the highest score over all positions wins.

    Args:
        text: Comma-separated species names describing the vegetation plot.

    Returns:
        A ``(sentence, image)`` tuple. The sentence names the predicted
        species, its position and the completed plot, and the image comes
        from ``return_species_image``. When no new species can be proposed,
        the sentence says so and the image is ``None``.
    """
    text = gbif_normalization(text)
    text_split = [name for name in text.split(', ') if name]
    known_species = set(text_split)

    max_score = 0
    best_prediction = None
    best_position = None
    best_sentence = None

    # Loop through each position in the sentence
    for i in range(len(text_split) + 1):
        # Create masked text
        masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])

        # Run the model once per position, then pick the first candidate that
        # is not already part of the plot (None if every candidate is known).
        predictions = mask_model(masked_text)
        prediction = next(
            (candidate for candidate in predictions
             if candidate['token_str'] not in known_species),
            None,
        )
        if prediction is None:
            continue

        # Update best prediction and position if score is higher
        if prediction['score'] > max_score:
            max_score = prediction['score']
            best_prediction = prediction['token_str']
            best_position = i
            best_sentence = prediction['sequence']

    if best_prediction is None:
        # Nothing usable was predicted; avoid looking up an image for None.
        return "No missing species could be predicted for this vegetation plot.", None

    text = f"The most likely missing species is {best_prediction} (position {best_position}).\nThe new vegetation plot is {best_sentence}."
    image = return_species_image(best_prediction)
    return text, image

# User interface: two tabs, one per model, each with a text input, a text
# output, an image output, a trigger button and one cached example.
with gr.Blocks() as demo:

    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                # step=1 restricts the slider to whole numbers, since the
                # value is used as a number of habitats to report.
                top_k = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose between 1 and 5.")
            with gr.Column():
                text_output_1 = gr.Textbox()
                text_output_2 = gr.Image()
        text_button = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            [["sparganium erectum, calystegia sepium, persicaria amphibia", 1]],
            [species, top_k],
            outputs=[text_output_1, text_output_2],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            species_2 = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
            with gr.Column():
                image_output_1 = gr.Textbox()
                image_output_2 = gr.Image()
        image_button = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            [["vaccinium myrtillus, dryopteris dilatata, molinia caerulea"]],
            [species_2],
            outputs=[image_output_1, image_output_2],
            fn=masking,
            cache_examples=True,
        )

    # classification() takes both the species text and the top-k value, so
    # both components must be listed as inputs.
    text_button.click(classification, inputs=[species, top_k], outputs=[text_output_1, text_output_2])
    image_button.click(masking, inputs=[species_2], outputs=[image_output_1, image_output_2])

demo.launch()