import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
import rasterio
from rasterio.enums import Resampling

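# Cache of the most recent raw probability map, so the threshold slider can
# re-render it without another forward pass.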
pred_global = None

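# Land mask resampled (nearest neighbour) to the 900 x 1800 prediction grid;
# cells where the mask is not 1 are zeroed out in the rendered predictions.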
land_mask = (
    rasterio.open('data/LAND_MASK.tif').read(
        out_shape=(1, 900, 1800), resampling=Resampling.nearest
    ) == 1
).squeeze(0)

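# Cross-attention block: queries come from a spatial feature map (downsampled
# 4x by a strided conv), keys/values come from a text embedding, and the
# attended output is projected back to the input channel dimension.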
class Attn(nn.Module):
    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias=False, stride=4)
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        b, c, h, w = x.shape
        kv = self.to_kv(text).chunk(2, dim=-1)
        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), kv)
        q = self.to_q(x)
        q = rearrange(q, 'b (h c) x y -> b h (x y) c', h=self.heads)

        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)


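# Range-prediction head: a fixed grid of location features (model.x, loaded
# below from pos_embeds_model.npy) cross-attends to the text embedding, the
# result is reshaped back to a 225-row spatial map, bilinearly upsampled 4x to
# the 900 x 1800 grid, and projected to a single-channel logit map.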
class RangeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        self.x = None

    def forward(self, text):
        x = self.cross_attn(self.x, text)
        x = rearrange(x, 'b (h w) d -> b d h w', h=225)
        x = self.upsample(x)
        x = self.out(x)
        return x


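# Load the trained weights and the precomputed positional embeddings on CPU.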
model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()

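# Dictionaries mapping taxon names to precomputed 8192-dimensional text
# embeddings, one file per taxonomic level (the "_70b" suffix suggests they
# were produced by a 70B-parameter language model).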
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
family = np.load("data/family_70b.npy", allow_pickle=True)
# The "Genus" option below also needs genus embeddings; the file name is
# assumed to follow the same naming convention as the other levels.
genus = np.load("data/genus_70b.npy", allow_pickle=True)

species_list = list(species[()].keys())
class_list = list(clas[()].keys())
order_list = list(order[()].keys())
family_list = list(family[()].keys())
genus_list = list(genus[()].keys())

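# Swap the "Name" dropdown's choices to match the selected taxonomic level.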
def update_fn(val):
    if val == "Class":
        return gr.Dropdown(label="Name", choices=class_list, interactive=True)
    elif val == "Order":
        return gr.Dropdown(label="Name", choices=order_list, interactive=True)
    elif val == "Family":
        return gr.Dropdown(label="Name", choices=family_list, interactive=True)
    elif val == "Genus":
        return gr.Dropdown(label="Name", choices=genus_list, interactive=True)
    elif val == "Species":
        return gr.Dropdown(label="Name", choices=species_list, interactive=True)

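# Look up the text embedding for the chosen taxon, run the model, mask out
# non-land cells, and return the probability map rendered with the 'plasma'
# colormap.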
def text_fn(taxon, name):
    global pred_global
    if taxon == "Class":
        text_embeds = clas[()][name]
    elif taxon == "Order":
        text_embeds = order[()][name]
    elif taxon == "Family":
        text_embeds = family[()][name]
    elif taxon == "Genus":
        text_embeds = genus[()][name]
    elif taxon == "Species":
        text_embeds = species[()][name]

    text_embeds = torch.tensor(text_embeds)
    preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy() * land_mask
    pred_global = preds

    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    rgb_img = np.delete(rgba_img, 3, 2)
    return rgb_img

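# Binarise the cached prediction at the chosen confidence threshold and
# re-render it without re-running the model.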
def thresh_fn(val):
    global pred_global
    if pred_global is None:
        # Nothing to threshold until the model has been run at least once.
        return None
    preds = deepcopy(pred_global)
    preds[preds < val] = 0
    preds[preds >= val] = 1
    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    rgb_img = np.delete(rgba_img, 3, 2)
    return rgb_img

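# Gradio UI: pick a taxonomic level and name, run the model, and adjust the
# confidence threshold to binarise the predicted range map.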
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Hierarchical Species Distribution Model
This model predicts species distributions from geographic, environmental, and natural-language features.
""")
    with gr.Row():
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Family", "Genus", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Run Model")

    with gr.Row():
        # gr.Slider takes its initial position via `value=` (it has no `default` argument).
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)

    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, inputs=slider, outputs=pred)

demo.launch()