File size: 2,678 Bytes
6631b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd


def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of ``lst`` that contains ``word``.

    The comparison is a case-insensitive substring test.

    Args:
        lst: Sequence of strings to search.
        word: Substring to look for.

    Returns:
        The index of the first matching element, or -1 if no element matches.
    """
    # Lower-case the needle once instead of once per element.
    needle = word.lower()
    # The generator stops at the first hit, so the rest of the list is not scanned.
    return next(
        (i for i, element in enumerate(lst) if needle in element.lower()),
        -1,
    )

# Most recent prediction grid produced by text_fn; thresh_fn reads it back to
# apply a threshold. None until the model has been run once.
# NOTE(review): this is module-level state, so it is shared by everything that
# runs in this process — confirm that is acceptable for the demo.
pred_global = None

# Prediction matrix; text_fn selects one column per species.
# NOTE(review): the column order presumably matches the sorted species list
# `obs` below — confirm against whatever produced stl_species.npy.
stl_preds = np.load("stl_species.npy")

# Sorted list of the distinct species names. Only the "species" column is
# needed, so the other columns of the CSV are not parsed.
obs = sorted(
    pd.read_csv("gbif_full_filtered.csv", usecols=["species"])["species"]
    .drop_duplicates()
    .tolist()
)


def update_fn(val):
    """Build the "Name" dropdown for the selected taxonomic level.

    Args:
        val: Selected level: "Class", "Order", "Family", "Genus" or "Species".

    Returns:
        An interactive ``gr.Dropdown`` labelled "Name" whose choices are the
        names for that level, or ``None`` when ``val`` is none of the levels
        listed above.
    """
    # NOTE(review): class_list, order_list, family_list and genus_list are not
    # defined in the visible part of this file; each is only looked up when
    # its own branch is taken.
    if val == "Class":
        names = class_list
    elif val == "Order":
        names = order_list
    elif val == "Family":
        names = family_list
    elif val == "Genus":
        names = genus_list
    elif val == "Species":
        names = obs
    else:
        return None

    return gr.Dropdown(label="Name", choices=names, interactive=True)

def text_fn(taxon, name):
    """Render the prediction for the selected species as an RGB heatmap.

    The un-coloured prediction grid is also stored in ``pred_global`` so that
    ``thresh_fn`` can threshold it afterwards.

    Args:
        taxon: Selected taxonomic level. Not used by this function.
        name: Species name, matched against ``obs`` as a case-insensitive
            substring.

    Returns:
        Array of shape (510, 510, 3): the prediction grid mapped through the
        "plasma" colormap with the alpha channel removed.

    Raises:
        gr.Error: If ``name`` is empty/None or no entry of ``obs`` contains it.
    """
    global pred_global

    # Without this guard a missing selection fails on ``name.lower()`` inside
    # the lookup helper with an unhelpful AttributeError.
    if not name:
        raise gr.Error("Please select a name before running the model.")

    species_index = get_index_of_element_containing_word(obs, name)
    # -1 means "not found"; used as an index it would silently select the last
    # column and display the wrong species.
    if species_index < 0:
        raise gr.Error(f"No species matching '{name}' was found.")

    # One column per species, reshaped to a 510 x 510 grid and mirrored along
    # axis 1.
    preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
    pred_global = preds

    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    # Drop the alpha channel (index 3 of the last axis) to get an RGB image.
    rgb_img = np.delete(rgba_img, 3, axis=2)
    return rgb_img

def thresh_fn(val):
    """Binarise the last prediction grid at ``val`` and render it as RGB.

    Args:
        val: Threshold; cells below it become 0, cells at or above it become 1.

    Returns:
        The thresholded grid mapped through the "plasma" colormap with the
        alpha channel removed, or ``None`` if no prediction has been stored in
        ``pred_global`` yet.
    """
    global pred_global

    # The slider can be moved before the model has been run; there is nothing
    # to threshold then, and comparing None with a number would raise.
    if pred_global is None:
        return None

    # Work on a copy so the stored prediction can be re-thresholded later.
    preds = deepcopy(pred_global)
    preds[preds < val] = 0
    preds[preds >= val] = 1

    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    # Drop the alpha channel (index 3 of the last axis) to get an RGB image.
    rgb_img = np.delete(rgba_img, 3, axis=2)
    return rgb_img

# UI layout: taxonomic-level and name dropdowns, a run button, a threshold
# slider, and the heatmap image.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        # Repopulate the "Name" dropdown whenever the level changes.
        inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Run Model")

    with gr.Row():
        # `value` sets the initial slider position; `default` was the legacy
        # keyword and is not the Blocks-era parameter name.
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)

    # "Run Model" draws the heatmap; moving the slider redraws it thresholded.
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)

demo.launch()