# StLouis-SDM / app.py
# Author: Vishu26 · commit c9f086f ("wip") · 3.69 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd
from PIL import Image
def get_index_of_element_containing_word(lst, word):
    """Return the index of the element of *lst* that best matches *word*.

    Matching is case-insensitive. An element equal to *word* wins; otherwise
    the first element that contains *word* as a substring is used.

    Args:
        lst: Sequence of strings to search.
        word: String to look for.

    Returns:
        The index of the matching element, or -1 if no element contains *word*.
    """
    needle = word.lower()  # lowercase the query once, not once per element
    lowered = [element.lower() for element in lst]
    # An exact match takes priority, so a full name is never shadowed by an
    # earlier element that merely contains it as a substring.
    try:
        return lowered.index(needle)
    except ValueError:
        pass
    # Fall back to the first substring match, stopping at the first hit.
    return next((i for i, element in enumerate(lowered) if needle in element), -1)
# Mutable UI state shared by the Gradio callbacks.
pred_global = None  # raw prediction grid of the last "Run Model"; None until then
alpha_global = 0.5  # blend weight of the heatmap over the base map
alpha_image = None  # grid currently displayed (raw or thresholded); None until a run

# Per-species predictions: text_fn reads column i as the grid for obs[i].
stl_preds = np.load("stl_species.npy")

# Sorted, de-duplicated species names. Only the "species" column is needed,
# so the rest of the CSV is not loaded. text_fn uses positions in this list
# as column indices into stl_preds, so this order must match the column
# order of stl_species.npy.
obs = sorted(
    pd.read_csv("gbif_full_filtered.csv", usecols=["species"])["species"]
    .drop_duplicates()
    .tolist()
)

# Base map the heatmaps are blended onto. It is converted to RGB to match the
# colormapped overlay. convert() loads the pixel data, so the file can be
# closed right away.
with Image.open("stl_base.png") as _base_file:
    stl_base = _base_file.convert("RGB")
def update_fn(val):
    """Return a "Name" dropdown listing the names for taxonomic level *val*.

    Args:
        val: Selected taxonomic level ("Class", "Order", "Family", "Genus"
            or "Species").

    Returns:
        A ``gr.Dropdown`` whose choices are the names at that level.
        Unrecognised values (e.g. ``None`` when the selection is cleared)
        and levels whose name list is not defined give an empty list.
    """
    # Only the species list (`obs`) is built in this file. The other levels'
    # lists are looked up by name when the function is called. Selecting a
    # level whose list does not exist therefore yields no choices instead of
    # raising NameError.
    list_names = {
        "Class": "class_list",
        "Order": "order_list",
        "Family": "family_list",
        "Genus": "genus_list",
    }
    if val == "Species":
        choices = obs
    else:
        choices = globals().get(list_names.get(val), [])
    return gr.Dropdown(label="Name", choices=choices, interactive=True)
def _render_overlay(values):
    """Colour *values* with the plasma colormap and blend it over ``stl_base``.

    Args:
        values: 2-D float array. Matplotlib spreads floats in [0, 1] across
            the colormap. The shape must be (stl_base.height, stl_base.width),
            otherwise ``Image.blend`` rejects the size mismatch.

    Returns:
        An RGB ``uint8`` array. ``alpha_global`` is the blend weight:
        0 shows only ``stl_base`` and 1 shows only the heatmap.
    """
    rgba = plt.get_cmap("plasma")(values)
    rgb = (rgba[..., :3] * 255).astype(np.uint8)  # drop the alpha channel
    blended = Image.blend(stl_base, Image.fromarray(rgb), alpha_global)
    return np.array(blended)


def text_fn(taxon, name):
    """Render the predicted distribution of the species selected in the UI.

    Args:
        taxon: Value of the "Taxonomic Hierarchy" dropdown. It is unused
            because only species-level predictions (``stl_preds``) are loaded.
        name: Species name from the "Name" dropdown.

    Returns:
        The prediction heatmap blended over ``stl_base``.

    Raises:
        gr.Error: If no name is selected or *name* matches no entry of ``obs``.
    """
    global pred_global, alpha_image
    if not name:
        raise gr.Error("Please select a species name first.")
    species_index = get_index_of_element_containing_word(obs, name)
    if species_index == -1:
        # Indexing with -1 would silently display the last species' map.
        raise gr.Error(f"Unknown species: {name}")
    # Column `species_index` is one flattened prediction grid. Reshape it to
    # stl_base's pixel size (the blend needs equal sizes; the data were
    # 510x510 when written), then mirror it left-right (axis 1), presumably
    # to align with the base map's orientation.
    grid = stl_preds[:, species_index].reshape(stl_base.height, stl_base.width)
    preds = np.flip(grid, 1)
    # Keep the raw prediction for the threshold and transparency sliders.
    pred_global = preds
    alpha_image = preds
    return _render_overlay(preds)


def thresh_fn(val):
    """Show the last prediction as a binary mask at confidence threshold *val*.

    Pixels whose prediction is >= *val* become 1 and the rest become 0. The
    mask replaces ``alpha_image``, so later transparency changes keep showing
    it.

    Returns:
        The blended mask, or ``None`` (the image stays empty) if "Run Model"
        has not produced a prediction yet.
    """
    global alpha_image
    if pred_global is None:
        return None
    # A single comparison builds a fresh mask and leaves pred_global intact.
    # Assigning 0 in place and then 1 for `>= val` would turn those zeros
    # back into 1 whenever val <= 0.
    alpha_image = np.where(pred_global >= val, 1.0, 0.0)
    return _render_overlay(alpha_image)


def alpha_fn(val):
    """Re-blend the currently displayed grid using heatmap weight *val*.

    Stores *val* in ``alpha_global`` so that later renders use it too.

    Returns:
        The re-blended image, or ``None`` if nothing has been rendered yet.
    """
    global alpha_global
    alpha_global = val
    if alpha_image is None:
        return None
    return _render_overlay(alpha_image)
with gr.Blocks() as demo:
    gr.Markdown(
        """
# St Louis Species Distribution Model!
This model predicts the distribution of species based on geographic and satellite image features.
""")
    with gr.Row():
        # "Species" is the only level offered, so preselect it and fill the
        # name list up front instead of waiting for a change event.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"], value="Species")
        out = gr.Dropdown(label="Name", choices=obs, interactive=True)
        inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    # Event listeners must be registered inside the Blocks context.
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    alpha.change(alpha_fn, alpha, outputs=pred)

# Launch only when run as a script, so importing the module does not start
# (and block on) the server.
if __name__ == "__main__":
    demo.launch()