"""Gradio demo for a hierarchical species distribution model.

The script loads a land mask, a pretrained ``RangeModel`` checkpoint, cached
positional embeddings and per-taxon text embeddings, then serves a small UI:
pick a taxonomic level and a name, run the model to get a heatmap, and
optionally binarise the heatmap with a confidence threshold.
"""
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
import rasterio
from rasterio.enums import Resampling

# Most recent masked prediction (2-D array); written by text_fn and read by
# thresh_fn. None until the model has been run at least once.
pred_global = None

# Boolean land mask resampled to a 900 x 1800 grid; True where the raster
# value equals 1. The dataset is opened in a context manager so the file
# handle is released as soon as the array has been read.
with rasterio.open('data/LAND_MASK.tif') as _mask_src:
    land_mask = (
        _mask_src.read(out_shape=(1, 900, 1800), resampling=Resampling.nearest) == 1
    ).squeeze(0)


class Attn(nn.Module):
    """Multi-head cross-attention from a 2-D feature map to text embeddings.

    Queries come from ``x`` through a 4x4 convolution with stride 4, so each
    query summarises one non-overlapping 4x4 patch. Keys and values come from
    ``text`` through a single linear projection that is split in two.

    Args:
        dim: Channel count of the feature map and of the output tokens.
        dim_text: Feature size of the text embeddings.
        heads: Number of attention heads.
        dim_head: Feature size per head.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias=False, stride=4)
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        """Attend from image patches to text tokens.

        Args:
            x: Feature map of shape ``(b, dim, H, W)``.
            text: Text embeddings of shape ``(b, n_text, dim_text)``.

        Returns:
            Tensor of shape ``(b, n_patches, dim)`` where ``n_patches`` is the
            number of spatial positions produced by the query convolution.
        """
        keys_values = self.to_kv(text).chunk(2, dim=-1)
        k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in keys_values
        )

        q = rearrange(self.to_q(x), 'b (h c) x y -> b h (x y) c', h=self.heads)

        # Scaled dot-product attention, normalised over the text tokens.
        attn = (torch.matmul(q, k.transpose(-1, -2)) * self.scale).softmax(dim=-1)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class RangeModel(nn.Module):
    """Maps text embeddings to a single-channel logit map.

    ``self.x`` holds the spatial features used as attention queries. It is
    not a parameter and is not part of the checkpoint; it must be assigned
    before calling ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        self.x = None

    def forward(self, text):
        """Return logits of shape ``(b, 1, H, W)`` for the given text embeddings."""
        x = self.cross_attn(self.x, text)
        # The token grid is 225 rows high: the stride-4 query convolution
        # shrinks the input 4x and the upsample below restores it (225 * 4 =
        # 900, the height the land mask is resampled to).
        x = rearrange(x, 'b (h w) d -> b d h w', h=225)
        x = self.upsample(x)
        return self.out(x)


model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))

# NOTE(review): allow_pickle=True unpickles arbitrary objects on load. That is
# only acceptable for trusted files; these are presumably bundled with the
# app - confirm they never come from user input.
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()

# Each file stores a pickled dict (name -> text embedding) inside a 0-d
# object array; indexing with [()] extracts the dict.
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
family = np.load("data/family_70b.npy", allow_pickle=True)

species_list = list(species[()].keys())
class_list = list(clas[()].keys())
order_list = list(order[()].keys())
family_list = list(family[()].keys())

# Taxonomic level -> {name: embedding}. Genus embeddings are not loaded, so
# "Genus" is deliberately absent: offering it in the UI would crash on an
# undefined table. Add it here once the genus file is available.
TAXON_EMBEDDINGS = {
    "Class": clas[()],
    "Order": order[()],
    "Family": family[()],
    "Species": species[()],
}

# Taxonomic level -> list of selectable names for the second dropdown.
TAXON_NAMES = {
    "Class": class_list,
    "Order": order_list,
    "Family": family_list,
    "Species": species_list,
}

# Levels offered in the UI, in hierarchy order.
TAXON_LEVELS = list(TAXON_EMBEDDINGS)


def _to_rgb(values):
    """Colour a 2-D array with the 'plasma' colormap and drop the alpha channel."""
    rgba_img = plt.get_cmap('plasma')(values)
    return np.delete(rgba_img, 3, axis=2)


def update_fn(val):
    """Return a "Name" dropdown populated for the selected taxonomic level.

    Args:
        val: Selected taxonomic level, e.g. ``"Class"``.

    Returns:
        A ``gr.Dropdown`` whose choices are the names known for ``val``; the
        choices are empty when ``val`` is missing or has no loaded data.
    """
    return gr.Dropdown(label="Name", choices=TAXON_NAMES.get(val, []), interactive=True)


def text_fn(taxon, name):
    """Run the model for one taxon name and return the heatmap as an RGB image.

    The masked prediction is also stored in ``pred_global`` so that
    ``thresh_fn`` can threshold it without re-running the model.

    Args:
        taxon: Selected taxonomic level.
        name: Selected name within that level.

    Returns:
        Float RGB array of shape ``(H, W, 3)``.

    Raises:
        gr.Error: If the level or name is missing or has no embedding.
    """
    global pred_global

    table = TAXON_EMBEDDINGS.get(taxon)
    if table is None or name not in table:
        raise gr.Error("Choose a taxonomic level and a name before running the model.")

    text_embeds = torch.tensor(table[name])
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(text_embeds)

    # Probabilities on the 2-D grid, zeroed wherever the mask is not land.
    preds = logits.sigmoid().squeeze(0).squeeze(0).numpy() * land_mask
    pred_global = preds
    return _to_rgb(preds)


def thresh_fn(val):
    """Binarise the last prediction at ``val`` and return it as an RGB image.

    Cells with probability below ``val`` become 0, all others become 1.

    Args:
        val: Threshold taken from the slider.

    Returns:
        Float RGB array of shape ``(H, W, 3)``, or ``None`` when the model
        has not been run yet (there is nothing to threshold).
    """
    if pred_global is None:
        return None

    # Built from a comparison, so the stored prediction is never mutated.
    binary = (pred_global >= val).astype(pred_global.dtype)
    return _to_rgb(binary)


with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=TAXON_LEVELS)
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Run Model")

    with gr.Row():
        # `value` sets the initial position; `default` is not a Slider
        # parameter in current Gradio releases.
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)

    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)

demo.launch()