File size: 3,686 Bytes
6631b0d
 
 
 
 
ec04505
6631b0d
 
 
 
 
 
 
 
 
c9f086f
 
6631b0d
 
 
 
 
 
 
ec04505
 
6631b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
c9f086f
6631b0d
 
 
 
 
c9f086f
6631b0d
 
 
 
c9f086f
ec04505
6631b0d
 
 
 
c9f086f
6631b0d
 
 
c9f086f
6631b0d
 
 
 
c9f086f
 
 
 
 
 
 
 
 
 
 
 
 
6631b0d
 
 
 
 
d8d0a27
 
6631b0d
 
 
 
 
 
 
 
 
 
da36317
6631b0d
c9f086f
 
 
6631b0d
 
 
 
 
c9f086f
6631b0d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd
from PIL import Image


def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of ``lst`` that contains ``word``.

    Matching is a case-insensitive substring test. The search stops at the
    first hit, so elements after it are never inspected.

    Args:
        lst: Sequence of strings to search.
        word: Text to look for inside each element.

    Returns:
        The index of the first matching element, or ``-1`` when no element
        contains ``word``.
    """
    # A generator lets next() stop at the first match instead of building the
    # complete list of matching indices only to discard all but the first.
    return next(
        (
            position
            for position, element in enumerate(lst)
            if word.lower() in element.lower()
        ),
        -1,
    )

# Mutable state shared by the UI callbacks (text_fn, thresh_fn, alpha_fn).
pred_global = None   # raw prediction grid of the most recent model run
alpha_global = 0.5   # blend weight passed to Image.blend
alpha_image = None   # grid currently shown (raw or thresholded)

# Prediction matrix; text_fn selects one column per species using the
# position of the species name in `obs`.
stl_preds = np.load("stl_species.npy")

# Alphabetically sorted list of the distinct species names in the CSV.
df = pd.read_csv("gbif_full_filtered.csv")
obs = sorted(df["species"].drop_duplicates().tolist())
del df  # only the name list is needed; drop the table

# Base map that every heatmap is blended onto.
stl_base = Image.open("stl_base.png").convert("RGB")


def update_fn(val):
    """Build the "Name" dropdown for the selected taxonomic level.

    Args:
        val: Taxonomic level chosen in the hierarchy dropdown ("Class",
            "Order", "Family", "Genus" or "Species").

    Returns:
        An interactive ``gr.Dropdown`` labelled "Name" whose choices are the
        list for that level, or ``None`` when ``val`` is none of the levels
        above.
    """
    # NOTE(review): class_list, order_list, family_list and genus_list are not
    # defined in the visible part of this file; only the "Species" branch
    # (backed by `obs`) can be resolved from here -- confirm before exposing
    # the other levels in the UI.
    if val == "Class":
        options = class_list
    elif val == "Order":
        options = order_list
    elif val == "Family":
        options = family_list
    elif val == "Genus":
        options = genus_list
    elif val == "Species":
        options = obs
    else:
        return None
    return gr.Dropdown(label="Name", choices=options, interactive=True)

def text_fn(taxon, name):
    """Render the prediction heatmap for the selected name over the base map.

    The selected column of ``stl_preds`` is stored in ``pred_global`` and
    ``alpha_image`` so the threshold and transparency callbacks can re-render
    it later.

    Args:
        taxon: Value of the taxonomic-hierarchy dropdown. Not used here; it is
            accepted because the click handler passes both dropdown values.
        name: Value of the "Name" dropdown.

    Returns:
        An RGB ``numpy`` array: the plasma-coloured prediction blended with
        ``stl_base`` using the current ``alpha_global``.

    Raises:
        gr.Error: If no name was selected, or no entry of ``obs`` contains it.
    """
    global pred_global, alpha_image

    if not name:
        raise gr.Error("Please select a name before running the model.")

    species_index = get_index_of_element_containing_word(obs, name)
    if species_index < 0:
        # -1 is the helper's "not found" sentinel; used as a column index it
        # would silently select the last species instead of failing.
        raise gr.Error(f"No species matching '{name}' was found.")

    # Reshape the selected column into a 510x510 grid and mirror it along
    # axis 1.
    preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)

    pred_global = preds
    alpha_image = preds

    # The colormap yields RGBA; keep only the first three channels.
    rgba_img = plt.get_cmap('plasma')(preds)
    rgb_img = rgba_img[..., :3]
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    return np.array(Image.blend(stl_base, heatmap, alpha_global))

def thresh_fn(val):
    """Binarise the latest prediction at ``val`` and render it over the base map.

    Works on a copy of ``pred_global``: entries below ``val`` are set to 0,
    then entries greater than or equal to ``val`` are set to 1. The binarised
    grid is stored in ``alpha_image`` so the transparency callback re-renders
    the thresholded view.

    Args:
        val: Threshold taken from the confidence slider.

    Returns:
        An RGB ``numpy`` array of the blended image, or ``None`` when no
        prediction has been produced yet.
    """
    global alpha_image

    if pred_global is None:
        # The slider was moved before any model run; there is nothing to
        # threshold, and comparing None with a number would raise TypeError.
        return None

    # Copy first so the stored raw prediction is never modified.
    preds = deepcopy(pred_global)
    preds[preds < val] = 0
    preds[preds >= val] = 1
    alpha_image = preds

    # The colormap yields RGBA; keep only the first three channels.
    rgba_img = plt.get_cmap('plasma')(preds)
    rgb_img = rgba_img[..., :3]
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    return np.array(Image.blend(stl_base, heatmap, alpha_global))

def alpha_fn(val):
    """Store a new blend weight and re-render the currently shown grid.

    Args:
        val: Blend weight taken from the transparency slider; it is passed to
            ``Image.blend`` as the alpha argument.

    Returns:
        An RGB ``numpy`` array of ``alpha_image`` blended over ``stl_base``,
        or ``None`` when no prediction has been produced yet.
    """
    global alpha_global

    # Remember the weight even when nothing is displayed, so the next model
    # run uses it.
    alpha_global = val

    if alpha_image is None:
        # The slider was moved before any model run; there is no grid to
        # colour yet.
        return None

    # The colormap yields RGBA; keep only the first three channels.
    rgba_img = plt.get_cmap('plasma')(alpha_image)
    rgb_img = rgba_img[..., :3]
    heatmap = Image.fromarray((rgb_img * 255).astype(np.uint8))
    return np.array(Image.blend(stl_base, heatmap, alpha_global))

# Assemble the Gradio interface: a title, two dropdowns, a run button, two
# sliders and the output image, each group in its own row.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # St Louis Species Distribution Model!
    This model predicts the distribution of species based on geographic, and satellite image features.
    """)
    # Taxonomic level selector ("Species" is the only level offered) and the
    # name dropdown it populates.
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        # Changing the level rebuilds the name dropdown through update_fn.
        inp.change(update_fn, inp, out)
    
    with gr.Row():
        check_button = gr.Button("Run Model")
    
    # Threshold passed to thresh_fn.
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    
    # Blend weight passed to alpha_fn; the initial value matches alpha_global.
    with gr.Row():
        alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
    
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    
    # All three callbacks write their result to the same image component.
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    alpha.change(alpha_fn, alpha, outputs=pred)
    
# Launch the app; this runs at module top level with no __main__ guard.
demo.launch()