Vishu26's picture
data
9d3003c
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
import rasterio
from rasterio.enums import Resampling
# Most recent masked probability map produced by text_fn; thresh_fn
# re-thresholds it whenever the slider moves.
# NOTE(review): this is process-wide state, so concurrent users of the app
# overwrite each other's prediction; gr.State would make it per-session.
pred_global = None

# Boolean (900, 1800) mask: the raster is read as a single band resampled
# (nearest neighbour) to 900x1800, and pixels whose value equals 1 are True
# (presumably land — confirm against the raster's metadata). text_fn
# multiplies predictions by it. The dataset is opened in a `with` block so its
# file handle is released as soon as the data has been read.
with rasterio.open('data/LAND_MASK.tif') as _land_mask_src:
    land_mask = (
        _land_mask_src.read(out_shape=(1, 900, 1800), resampling=Resampling.nearest)
        == 1
    ).squeeze(0)
class Attn(nn.Module):
    """Multi-head cross-attention from patches of a 2-D feature map to text tokens.

    Queries come from a kernel-4 / stride-4 convolution over ``x``, so each
    query corresponds to one non-overlapping 4x4 patch of ``x``. Keys and
    values are a single linear projection of ``text``, split in two.

    Args:
        dim: Channel count of ``x`` (input channels of the query conv) and
            feature size of the output.
        dim_text: Size of the last axis of ``text``.
        heads: Number of attention heads.
        dim_head: Per-head feature size; ``heads * dim_head`` is the width of
            the query/key/value projections.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5  # 1/sqrt(d_head) dot-product scaling
        self.heads = heads
        hidden_dim = dim_head * heads
        # Patchify: every 4x4 patch of x becomes one query vector.
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias=False, stride=4)
        # One projection yields both keys and values; split in forward().
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        """Attend from the patches of ``x`` to the tokens of ``text``.

        Args:
            x: 4-D tensor (batch, dim, H, W) — four axes are required by the
                ``'b (h c) x y'`` rearrange pattern below.
            text: 3-D tensor (batch, n_tokens, dim_text) — required by the
                ``'b n (h d)'`` rearrange pattern below.

        Returns:
            Tensor (batch, n_patches, dim) with
            ``n_patches = (H // 4) * (W // 4)``, patches in row-major order.
        """
        # Split the joint projection into keys and values, then separate heads.
        k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.heads)
            for t in self.to_kv(text).chunk(2, dim=-1)
        )
        # Flatten the patch grid into a sequence of per-head query vectors.
        q = rearrange(self.to_q(x), "b (h c) x y -> b h (x y) c", h=self.heads)

        attn = (torch.matmul(q, k.transpose(-1, -2)) * self.scale).softmax(dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back into one feature axis before the output projection.
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class RangeModel(nn.Module):
    """Map a text embedding to a single-channel spatial score map.

    The attention queries are computed from ``self.x``, a positional-embedding
    tensor that must be assigned after construction. It is a plain attribute,
    not a parameter or buffer, so it is not part of ``state_dict``.

    Args:
        grid_h: Height of the patch grid produced by ``Attn`` from ``self.x``
            (its query conv uses stride 4, so ``H // 4`` for an input of
            height H). The grid width is inferred. Defaults to the original
            hard-coded 225.
    """

    def __init__(self, grid_h=225):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        # Scale factor 4 matches the stride-4 query convolution inside Attn.
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        self.grid_h = grid_h
        self.x = None

    def forward(self, text):
        """Return raw (pre-sigmoid) scores of shape (b, 1, 4 * grid_h, 4 * grid_w).

        Raises:
            RuntimeError: if ``self.x`` has not been assigned yet.
        """
        if self.x is None:
            raise RuntimeError(
                "RangeModel.x must be assigned a positional-embedding tensor "
                "before forward() is called"
            )
        x = self.cross_attn(self.x, text)
        # Fold the flat patch sequence back into a (b, d, h, w) grid.
        x = rearrange(x, 'b (h w) d -> b d h w', h=self.grid_h)
        x = self.upsample(x)
        return self.out(x)
# Build the network on CPU, restore its trained weights, and attach the fixed
# positional-embedding grid from which the attention queries are computed.
model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location="cpu"))
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()


def _load_taxon_table(path):
    """Load a pickled .npy file whose ``[()]`` item maps taxon names to text embeddings."""
    return np.load(path, allow_pickle=True)


species = _load_taxon_table("data/species_70b.npy")
clas = _load_taxon_table("data/class_70b.npy")
order = _load_taxon_table("data/order_70b.npy")
family = _load_taxon_table("data/family_70b.npy")
# NOTE: no genus table is loaded here, so the names `genus` and `genus_list`
# are not defined at module level.

# Dropdown choices for each rank: the names (keys) stored in each table.
species_list, class_list, order_list, family_list = (
    list(table[()].keys()) for table in (species, clas, order, family)
)
def update_fn(val):
    """Build the "Name" dropdown offering the names available at taxonomic level ``val``.

    Wired to the level dropdown's ``change`` event, so the returned component
    replaces the "Name" dropdown.

    Args:
        val: Selected level: "Class", "Order", "Family", "Genus" or "Species".

    Returns:
        A ``gr.Dropdown``. Levels without loaded data, and any unrecognised
        value such as ``None``, give an empty choice list instead of raising.
    """
    choices_by_level = {
        "Class": class_list,
        "Order": order_list,
        "Family": family_list,
        # The genus table's np.load is commented out at module level, so
        # `genus_list` may not exist; fall back to no choices, not NameError.
        "Genus": globals().get("genus_list", []),
        "Species": species_list,
    }
    return gr.Dropdown(
        label="Name", choices=choices_by_level.get(val, []), interactive=True
    )
def text_fn(taxon, name):
    """Predict and render the range map for taxon ``name`` at level ``taxon``.

    Looks up the stored text embedding and runs the model without gradient
    tracking. The sigmoid probabilities are multiplied by ``land_mask`` and
    cached in ``pred_global`` for thresh_fn. The result is returned as an RGB
    image coloured with the 'plasma' colormap.

    Args:
        taxon: Taxonomic level selected in the UI ("Class", "Order", ...).
        name: Name selected in the "Name" dropdown.

    Returns:
        ndarray (H, W, 3) of RGB values produced by the colormap.

    Raises:
        gr.Error: if no level/name is selected, the level has no loaded
            embeddings, or ``name`` is unknown at that level. Gradio shows
            the message to the user.
    """
    global pred_global
    if taxon is None or name is None:
        raise gr.Error("Select a taxonomic level and a name before running the model.")

    tables = {
        "Class": clas,
        "Order": order,
        "Family": family,
        # Genus embeddings are not loaded at module level (the np.load is
        # commented out), so `genus` may be undefined.
        "Genus": globals().get("genus"),
        "Species": species,
    }
    table = tables.get(taxon)
    if table is None:
        raise gr.Error(f"No embeddings are loaded for taxonomic level {taxon!r}.")
    embeddings = table[()]
    if name not in embeddings:
        raise gr.Error(f"{name!r} was not found at taxonomic level {taxon!r}.")

    # The model's layers use float32 weights; coerce the stored embedding to match.
    text_embeds = torch.tensor(embeddings[name], dtype=torch.float32)
    with torch.no_grad():  # inference only: skip building the autograd graph
        logits = model(text_embeds)
    preds = logits.sigmoid().squeeze(0).squeeze(0).numpy() * land_mask
    pred_global = preds

    rgba_img = plt.get_cmap('plasma')(preds)
    return np.delete(rgba_img, 3, 2)  # drop the alpha channel
def thresh_fn(val):
    """Render the cached prediction as a binary map thresholded at ``val``.

    Pixels of ``pred_global`` with value >= ``val`` become 1 and all others
    become 0. The map is then coloured with the 'plasma' colormap and the
    alpha channel is dropped.

    Args:
        val: Threshold taken from the confidence slider.

    Returns:
        ndarray (H, W, 3) of RGB values, or ``None`` if no prediction has
        been made yet.
    """
    if pred_global is None:
        # The slider can be moved before "Run Model" has produced anything.
        return None
    # One vectorised comparison replaces the copy plus two masked writes;
    # keeping the prediction's dtype leaves the colour lookup unchanged.
    binary = (pred_global >= val).astype(pred_global.dtype)
    rgba_img = plt.get_cmap('plasma')(binary)
    return np.delete(rgba_img, 3, 2)
# Gradio UI:
# - choose a taxonomic level, then a name at that level;
# - run the model to show its heatmap;
# - use the slider to re-render the cached prediction as a thresholded map.
with gr.Blocks() as demo:
    gr.Markdown(
"""
# Hierarchical Species Distribution Model!
This model predicts the distribution of species based on geographic, environmental, and natural language features.
""")
    with gr.Row():
        # "Genus" is not offered: the module never loads genus embeddings
        # (its np.load is commented out), so selecting it could only fail.
        inp = gr.Dropdown(
            label="Taxonomic Hierarchy",
            choices=["Class", "Order", "Family", "Species"],
        )
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        # A Slider's initial position is set with `value`; `default` is not
        # a Slider parameter.
        slider = gr.Slider(
            minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold"
        )
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
demo.launch()