"""Gradio demo for a hierarchical, text-conditioned species distribution model.

A taxonomic level and a name are chosen in the UI. The matching precomputed
embedding is fed to ``RangeModel``, and the resulting probability map is shown
as a 'Greens' heatmap. A slider re-renders the last map binarized at a
confidence threshold.
"""
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy

# Most recent raw probability map produced by `text_fn`; re-thresholded by
# `thresh_fn`.
# NOTE(review): this is module-level state, so every session of the app
# shares it.
pred_global = None


class Attn(nn.Module):
    """Cross-attention from a spatial feature map (queries) to text tokens (keys/values).

    Queries come from a 4x4, stride-4 convolution over ``x``. Keys and values
    are one linear projection of ``text``, split in two. The attended values
    are projected back to ``dim`` features and returned as a token sequence
    ``(b, n_queries, dim)``.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias=False, stride=4)
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        """Attend from ``x`` (b, dim, H, W) to ``text`` (b, n, dim_text)."""
        kv = self.to_kv(text).chunk(2, dim=-1)
        k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), kv)
        # One query token per spatial position of the downsampled grid.
        q = rearrange(self.to_q(x), "b (h c) x y -> b h (x y) c", h=self.heads)
        attn = (torch.matmul(q, k.transpose(-1, -2)) * self.scale).softmax(dim=-1)
        out = rearrange(torch.matmul(attn, v), "b h n d -> b n (h d)")
        return self.to_out(out)


class RangeModel(nn.Module):
    """Map a text embedding to a 1-channel logit map over a fixed positional grid."""

    def __init__(self):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode="bilinear")
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        # Positional-embedding feature map used as the attention queries.
        # It is a plain attribute assigned after construction, so it is not
        # part of the state dict.
        self.x = None

    def forward(self, text):
        x = self.cross_attn(self.x, text)
        # Fold the query tokens back into a grid with 225 rows.
        x = rearrange(x, "b (h w) d -> b d h w", h=225)
        x = self.upsample(x)
        return self.out(x)


model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device("cpu")))
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()

# Each file holds a dict ({name: embedding}) wrapped in a 0-d object array,
# hence the `[()]` unwrapping below.
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
family = np.load("data/family_70b.npy", allow_pickle=True)
# Genus embeddings are not loaded (the genus_70b.npy load is disabled), so
# the "Genus" level is not offered anywhere below.

species_list = list(species[()].keys())
class_list = list(clas[()].keys())
order_list = list(order[()].keys())
family_list = list(family[()].keys())

# Taxonomic level -> {name: embedding}. This is the single source of truth
# for which levels the handlers accept and the UI offers.
TAXON_EMBEDS = {
    "Class": clas[()],
    "Order": order[()],
    "Family": family[()],
    "Species": species[()],
}
# Taxonomic level -> selectable names, in the same order as TAXON_EMBEDS.
TAXON_NAMES = {
    "Class": class_list,
    "Order": order_list,
    "Family": family_list,
    "Species": species_list,
}


def update_fn(val):
    """Return a "Name" dropdown listing the names available at taxonomic level *val*.

    Returns None when *val* is not an offered level (e.g. the selection was cleared).
    """
    names = TAXON_NAMES.get(val)
    if names is None:
        return None
    return gr.Dropdown(label="Name", choices=names, interactive=True)


def _to_rgb(preds):
    """Color a 2-D float map (expected in [0, 1]) with 'Greens' and drop the alpha channel."""
    rgba_img = plt.get_cmap("Greens")(preds)
    return np.delete(rgba_img, 3, 2)


def text_fn(taxon, name):
    """Run the model on the embedding of *name* at level *taxon*; return an RGB heatmap.

    The raw probability map is cached in ``pred_global`` for ``thresh_fn``.

    Raises:
        gr.Error: if no valid level/name pair is selected.
    """
    global pred_global
    embeds_by_name = TAXON_EMBEDS.get(taxon)
    if embeds_by_name is None or name not in embeds_by_name:
        raise gr.Error("Please select a taxonomic level and a name first.")
    # .float() matches the float32 weights and positional embeddings.
    text_embeds = torch.tensor(embeds_by_name[name]).float()
    with torch.no_grad():
        preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).numpy()
    pred_global = preds
    return _to_rgb(preds)


def thresh_fn(val):
    """Binarize the cached prediction at threshold *val* and return it as an RGB image.

    Cells with probability >= *val* map to 1 and all others to 0. Returns
    None (no image) if the model has not been run yet.
    """
    if pred_global is None:
        return None
    binary = (pred_global >= val).astype(pred_global.dtype)
    return _to_rgb(binary)


with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=list(TAXON_EMBEDS))
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        # `value` (not `default`) sets the slider's initial position.
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)

demo.launch()