"""Gradio demo: St. Louis species distribution heatmaps.

Loads a precomputed prediction matrix (one column per species), lets the
user pick a species, and shows that species' prediction map colour-mapped
with ``plasma`` and alpha-blended over a base image of the area.
"""

from copy import deepcopy

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

# Grid shape each species column of ``stl_preds`` is reshaped to.
# NOTE(review): Image.blend requires both images to have the same size, so
# this presumably matches the pixel size of stl_base.png -- confirm.
MAP_SHAPE = (510, 510)


def get_index_of_element_containing_word(lst, word):
    """Return the index of the first element of *lst* that contains *word*.

    The comparison is a case-insensitive substring match.

    Args:
        lst: Sequence of strings to search.
        word: Substring to look for. ``None`` is treated as "not found".

    Returns:
        The index of the first matching element, or -1 if nothing matches.
    """
    if word is None:
        return -1
    needle = word.lower()
    # Stop at the first hit instead of collecting every matching index.
    return next(
        (i for i, element in enumerate(lst) if needle in element.lower()),
        -1,
    )


# NOTE: this state lives at module level, so it is shared by every browser
# session connected to the same server process.
pred_global = None   # raw prediction map of the most recently run species
alpha_global = 0.5   # blend factor handed to Image.blend
alpha_image = None   # map currently displayed (raw or thresholded)

stl_preds = np.load("stl_species.npy")
df = pd.read_csv("unique_species.csv")
obs = df["NameList"].tolist()
del df
stl_base = Image.open("stl_base.png").convert("RGB")

# Name lists selectable per taxonomic rank.
# NOTE(review): the original code also branched on "Class", "Order",
# "Family" and "Genus", but the lists it referenced (class_list, order_list,
# family_list, genus_list) are not defined in this file, so those branches
# could only raise NameError. Add them here once the lists exist.
TAXON_CHOICES = {"Species": obs}


def _render_overlay(preds):
    """Colour-map *preds* with 'plasma' and blend it over the base image.

    Args:
        preds: 2-D array accepted by a matplotlib colormap (presumably
            floats in [0, 1] -- confirm against how stl_species.npy is built).

    Returns:
        The blended image as an ``np.ndarray`` (via ``np.array`` on the
        blended RGB PIL image). ``alpha_global`` is the blend factor:
        0 shows only the base image, 1 shows only the heatmap.
    """
    rgba = plt.get_cmap('plasma')(preds)
    # Drop the alpha channel and convert to 8-bit RGB for PIL.
    rgb = (rgba[..., :3] * 255).astype(np.uint8)
    blended = Image.blend(stl_base, Image.fromarray(rgb), alpha_global)
    return np.array(blended)


def update_fn(val):
    """Return a "Name" dropdown populated for the selected taxonomic rank.

    Args:
        val: Value of the "Taxonomic Hierarchy" dropdown.

    Returns:
        A ``gr.Dropdown`` whose choices are the names known for *val*; the
        choice list is empty when no names are available for that rank.
    """
    return gr.Dropdown(
        label="Name",
        choices=TAXON_CHOICES.get(val, []),
        interactive=True,
    )


def text_fn(taxon, name):
    """Render the prediction heatmap for the species selected by *name*.

    Stores the raw map in ``pred_global`` and ``alpha_image`` so the
    threshold and transparency sliders can re-render it later.

    Args:
        taxon: Selected taxonomic rank. Unused; kept because the click
            handler is wired with both dropdowns as inputs.
        name: Species name; matched case-insensitively as a substring
            against ``obs``.

    Returns:
        The heatmap blended over the base image, as an array.

    Raises:
        gr.Error: If no name is selected or it matches no known species.
    """
    global pred_global, alpha_global, alpha_image
    if not name:
        raise gr.Error("Please select a name before running the model.")
    species_index = get_index_of_element_containing_word(obs, name)
    if species_index < 0:
        # Without this check, -1 would silently index the LAST species
        # column and display an unrelated map.
        raise gr.Error(f"No species matching {name!r} was found.")
    # Flip along axis 1 (left-right) after reshaping the column to a grid.
    preds = np.flip(stl_preds[:, species_index].reshape(MAP_SHAPE), 1)
    pred_global = preds
    alpha_image = preds
    return _render_overlay(preds)


def thresh_fn(val):
    """Binarise the current prediction map at *val* and re-render it.

    Args:
        val: Confidence threshold from the slider.

    Returns:
        The thresholded heatmap blended over the base image, or ``None``
        if no prediction has been run yet.
    """
    global pred_global, alpha_global, alpha_image
    if pred_global is None:
        # Slider moved before "Run Model": nothing to threshold yet.
        return None
    # Copy first: pred_global is a flipped view into stl_preds, and the
    # in-place writes below must not modify the loaded predictions.
    preds = deepcopy(pred_global)
    # NOTE(review): the source line here was corrupted
    # ("preds[preds=val] = 1"); reconstructed as the usual binarisation.
    preds[preds < val] = 0
    preds[preds >= val] = 1
    alpha_image = deepcopy(preds)
    return _render_overlay(preds)


def alpha_fn(val):
    """Update the blend factor and re-render the currently displayed map.

    Args:
        val: New blend factor from the "Image Transparency" slider.

    Returns:
        The re-blended image, or ``None`` if no prediction has been run
        yet (the new factor is still stored for the next render).
    """
    global pred_global, alpha_global, alpha_image
    alpha_global = val
    if alpha_image is None:
        return None
    return _render_overlay(alpha_image)


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # St Louis Species Distribution Model!
        This model predicts the distribution of species based on geographic, and satellite image features.
        """)
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Image Transparency")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    alpha.change(alpha_fn, alpha, outputs=pred)

demo.launch()