StLouis-SDM / app.py
Vishu26's picture
wip
ec04505
raw
history blame
2.85 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd
from PIL import Image
def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of ``lst`` that contains ``word``.

    The match is a case-insensitive substring test.

    Args:
        lst: Sequence of strings to search.
        word: Substring to look for.

    Returns:
        The index of the first matching element, or -1 if no element
        contains ``word``.
    """
    # Lower-case the needle once instead of once per element.
    needle = word.lower()
    # Stop at the first hit rather than collecting every matching index.
    return next((i for i, element in enumerate(lst) if needle in element.lower()), -1)
# Most recent prediction grid rendered by ``text_fn``; ``thresh_fn`` reads it.
pred_global = None
# Prediction matrix: ``text_fn`` takes column i as the per-pixel map for
# species ``obs[i]`` and reshapes it to 510x510.
stl_preds = np.load("stl_species.npy")
# Sorted, de-duplicated species names. Only the "species" column is used, so
# parse just that column instead of the whole CSV.
obs = sorted(
    pd.read_csv("gbif_full_filtered.csv", usecols=["species"])["species"]
    .drop_duplicates()
    .tolist()
)
# Base map that prediction heatmaps are blended onto. The context manager
# closes the source file once the RGB copy has been made.
with Image.open("stl_base.png") as _base_img:
    stl_base = _base_img.convert("RGB")
def update_fn(val):
    """Return the "Name" dropdown populated for the chosen taxonomic level.

    Args:
        val: Value of the "Taxonomic Hierarchy" dropdown.

    Returns:
        A ``gr.Dropdown`` whose choices are the names available at level
        ``val``. Only species-level names (``obs``) are loaded by this app, so
        any other level, or a cleared selection, yields an empty choice list.
    """
    # NOTE(review): "Class", "Order", "Family" and "Genus" used to be handled
    # here via class_list / order_list / family_list / genus_list, but none of
    # those names is defined in this file, so those branches could only raise
    # NameError. Add them back once the corresponding name lists are loaded.
    choices = obs if val == "Species" else []
    return gr.Dropdown(label="Name", choices=choices, interactive=True)
def text_fn(taxon, name):
    """Render the predicted distribution map for one species.

    Also stores the flipped 510x510 prediction grid in the module-level
    ``pred_global`` so that ``thresh_fn`` can re-threshold it later.

    Args:
        taxon: Selected taxonomic level. Unused: only species-level maps exist.
        name: Species name chosen in the "Name" dropdown.

    Returns:
        A uint8 RGB array: the plasma-colormapped prediction blended 50/50
        with the ``stl_base`` map.

    Raises:
        gr.Error: If no name is selected or the name matches no known species.
    """
    global pred_global
    if not name:
        raise gr.Error("Please select a species name before running the model.")
    # Dropdown values are exact entries of ``obs``, so try an exact match
    # first; the case-insensitive substring search could otherwise return an
    # earlier species whose name merely contains ``name``.
    try:
        species_index = obs.index(name)
    except ValueError:
        species_index = get_index_of_element_containing_word(obs, name)
    # -1 would silently select the *last* species column, so reject it.
    if species_index < 0:
        raise gr.Error(f"Unknown species: {name}")
    # Flipped left-right (axis 1), presumably to align with the base map — confirm.
    preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
    pred_global = preds
    # Drop the colormap's alpha channel to keep float RGB.
    # NOTE(review): assumes predictions lie in [0, 1] — confirm.
    rgb_img = plt.get_cmap("plasma")(preds)[..., :3]
    overlay = Image.fromarray((rgb_img * 255).astype(np.uint8))
    # NOTE(review): Image.blend needs stl_base to match the overlay's size
    # (510x510) and mode (RGB) — confirm against stl_base.png.
    return np.array(Image.blend(stl_base, overlay, 0.5))
def thresh_fn(val):
    """Render the last prediction grid as a binary mask at threshold ``val``.

    Args:
        val: Confidence threshold from the slider (0 to 1 in this UI).

    Returns:
        A uint8 RGB array in which pixels whose prediction is >= ``val`` get the
        plasma colour for 1.0 and all other pixels get the colour for 0.0.

    Raises:
        gr.Error: If no prediction has been computed yet.
    """
    if pred_global is None:
        # The slider can be moved before "Run Model" has ever been clicked.
        raise gr.Error("Run the model before adjusting the threshold.")
    # np.where builds a fresh array, so the stored prediction is never mutated
    # (no deepcopy needed). It also avoids the order dependence of two in-place
    # masks, which re-flipped freshly zeroed pixels to 1 whenever val <= 0.
    mask = np.where(pred_global >= val, 1.0, 0.0)
    rgb_img = plt.get_cmap("plasma")(mask)[..., :3]
    # Same uint8 RGB output format as text_fn.
    return (rgb_img * 255).astype(np.uint8)
with gr.Blocks() as demo:
    gr.Markdown(
        "# St Louis Species Distribution Model!\n"
        "This model predicts the distribution of species based on geographic "
        "and satellite image features."
    )
    with gr.Row():
        # "Species" is the only level with loaded names, so preselect it and
        # fill the name list up front rather than leaving it empty until the
        # user picks the single available level.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"], value="Species")
        out = gr.Dropdown(label="Name", choices=obs, interactive=True)
    inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)

# Launch only when run as a script, so importing the module (e.g. for tests or
# tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()