# StLouis-SDM / app.py
# Committed by Vishu26 (commit cd9940f: "wip")
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import pandas as pd
from PIL import Image
def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of *lst* that contains *word*.

    Matching is a case-insensitive substring test (both sides lower-cased).

    Args:
        lst: Sequence of strings to search.
        word: Substring to look for.

    Returns:
        The index of the first element containing *word*, or ``-1`` if no
        element contains it.
    """
    # Lower-case the needle once rather than once per element.
    needle = word.lower()
    # next() stops at the first hit instead of collecting every matching
    # index only to keep the first one.
    return next(
        (i for i, element in enumerate(lst) if needle in element.lower()),
        -1,
    )
# --- Mutable state shared by the Gradio callbacks ---------------------------
# Raw (un-thresholded) prediction grid of the species last run via text_fn;
# None until "Run Model" has been clicked.
pred_global = None
# Weight passed to Image.blend when overlaying the heatmap on the base map.
alpha_global = 0.5
# Grid most recently rendered (raw or thresholded); alpha_fn re-renders it.
alpha_image = None

# --- Static data loaded once at start-up ------------------------------------
# text_fn takes column j of this array as the predictions for species obs[j].
stl_preds = np.load("stl_species.npy")
# Species names, read straight from the "NameList" column (no DataFrame kept).
obs = pd.read_csv("unique_species.csv")["NameList"].tolist()
# Base map the heatmaps are blended onto.
stl_base = Image.open("stl_base.png").convert("RGB")
def update_fn(val):
    """Build the "Name" dropdown for the taxonomic level *val*.

    Args:
        val: Level chosen in the "Taxonomic Hierarchy" dropdown; one of
            "Class", "Order", "Family", "Genus" or "Species".

    Returns:
        A new interactive ``gr.Dropdown`` listing the names for that level,
        or ``None`` when *val* is none of the recognised levels.
    """
    # Each entry pairs a level with a thunk so a name list is only looked up
    # when that level is actually requested.
    # NOTE(review): class_list, order_list, family_list and genus_list are not
    # defined in the code shown, so those levels would raise NameError; the
    # UI currently offers only "Species".
    levels = (
        ("Class", lambda: class_list),
        ("Order", lambda: order_list),
        ("Family", lambda: family_list),
        ("Genus", lambda: genus_list),
        ("Species", lambda: obs),
    )
    for level, names in levels:
        if val == level:
            return gr.Dropdown(label="Name", choices=names(), interactive=True)
    return None
def text_fn(taxon, name):
    """Render the predicted distribution of the selected species.

    Looks *name* up in ``obs``, takes the matching column of ``stl_preds``,
    colours it with the "plasma" colormap and blends it over ``stl_base``
    with weight ``alpha_global``.

    Args:
        taxon: Value of the "Taxonomic Hierarchy" dropdown. Not used here;
            kept because the "Run Model" click handler passes it.
        name: Name selected in the "Name" dropdown.

    Returns:
        The blended image as a NumPy array (from an RGB PIL image).

    Raises:
        gr.Error: If no name is selected or *name* matches no entry of
            ``obs``. Previously an unmatched name produced index -1 and
            silently rendered the last column of ``stl_preds``.
    """
    global pred_global, alpha_image
    if not name:
        raise gr.Error("Please select a species name before running the model.")
    # Dropdown values are exact entries of `obs`, so try an exact match first;
    # a substring search alone can return an earlier entry that merely
    # contains `name` (e.g. "Northern Cardinal" when "Cardinal" was chosen).
    if name in obs:
        species_index = obs.index(name)
    else:
        species_index = get_index_of_element_containing_word(obs, name)
    if species_index < 0:
        raise gr.Error(f"No prediction available for {name!r}.")
    # Column -> 510x510 grid, mirrored along axis 1 (left-right), presumably
    # to align with stl_base. Image.blend requires both images to have the
    # same size, so stl_base must be 510x510.
    preds = np.flip(stl_preds[:, species_index].reshape(510, 510), 1)
    # Keep the raw grid: thresh_fn re-thresholds pred_global and alpha_fn
    # re-renders alpha_image when their sliders move.
    pred_global = preds
    alpha_image = preds
    # The colormap yields RGBA floats (assumes predictions lie in [0, 1]);
    # keep RGB only and convert to 8-bit for PIL.
    rgb = plt.get_cmap('plasma')(preds)[..., :3]
    heatmap = Image.fromarray((rgb * 255).astype(np.uint8))
    blend = Image.blend(stl_base, heatmap, alpha_global)
    return np.array(blend)
def thresh_fn(val):
    """Re-render the current prediction as a binary map at threshold *val*.

    Values below *val* become 0 and values ``>= val`` become 1 (NaNs, if
    any, are left unchanged). The result is coloured with "plasma" and
    blended over ``stl_base`` with weight ``alpha_global``.

    Args:
        val: Confidence threshold from the slider.

    Returns:
        The blended image as a NumPy array, or ``None`` if "Run Model" has
        not produced a prediction yet.
    """
    global alpha_image
    if pred_global is None:
        # The slider can move before any prediction exists; previously the
        # comparison on None raised TypeError here.
        return None
    # Compute both masks from the untouched data. Assigning sequentially on
    # the same array let values zeroed in the first step be re-set to 1 in
    # the second whenever val <= 0.
    below = pred_global < val
    at_or_above = pred_global >= val
    # Copy so the raw grid in pred_global survives for later threshold changes.
    preds = pred_global.copy()
    preds[below] = 0
    preds[at_or_above] = 1
    # alpha_fn re-renders whatever was shown last, so store the binary map.
    alpha_image = preds
    rgb = plt.get_cmap('plasma')(preds)[..., :3]
    heatmap = Image.fromarray((rgb * 255).astype(np.uint8))
    return np.array(Image.blend(stl_base, heatmap, alpha_global))
def alpha_fn(val):
    """Set the heatmap blend weight and re-render the last map shown.

    Args:
        val: New weight for ``Image.blend``; 0 shows only ``stl_base`` and
            1 shows only the heatmap.

    Returns:
        The re-blended image as a NumPy array, or ``None`` if nothing has
        been rendered yet.
    """
    global alpha_global
    # Record the weight even before a prediction exists, so the next render
    # (text_fn / thresh_fn) picks it up.
    alpha_global = val
    if alpha_image is None:
        # Previously the colormap was applied to None and raised.
        return None
    # cmap() returns a new array, so no defensive copy of alpha_image is needed.
    rgb = plt.get_cmap('plasma')(alpha_image)[..., :3]
    heatmap = Image.fromarray((rgb * 255).astype(np.uint8))
    return np.array(Image.blend(stl_base, heatmap, alpha_global))
# Build the UI: pick a species, run the model, then adjust the confidence
# threshold and the heatmap transparency of the rendered map.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# St Louis Species Distribution Model!
This model predicts the distribution of species based on geographic, and satellite image features.
"""
    )
    with gr.Row():
        # "Species" is the only hierarchy level offered, so pre-select it and
        # pre-populate the names; otherwise the Name dropdown starts empty
        # until the user redundantly picks "Species".
        inp = gr.Dropdown(
            label="Taxonomic Hierarchy", choices=["Species"], value="Species"
        )
        out = gr.Dropdown(label="Name", choices=obs, interactive=True)
    inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    alpha.change(alpha_fn, alpha, outputs=pred)

# Only launch when run as a script; tools that import this module to find
# `demo` (e.g. Gradio reload mode) would otherwise launch a second server.
if __name__ == "__main__":
    demo.launch()