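# Gradio demo for a hierarchical species-distribution model: a text embedding for the
# selected taxon is cross-attended against a fixed grid of location features to produce
# a predicted presence heatmap.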
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class Attn(nn.Module):
    """Cross-attention: spatial features provide the queries, text embeddings the keys/values."""

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # Queries come from the feature map, spatially downsampled 4x by a strided conv.
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias=False, stride=4)
        # Keys and values are projected from the text embedding.
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        b, c, h, w = x.shape
        kv = self.to_kv(text).chunk(2, dim=-1)
        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), kv)
        q = self.to_q(x)
        q = rearrange(q, 'b (h c) x y -> b h (x y) c', h=self.heads)
        # Scaled dot-product attention over the text tokens for every spatial query.
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
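

# RangeModel cross-attends a fixed (1, 128, H, W) grid of location features against an
# 8192-dim text embedding, folds the attended tokens back into a 225-row spatial grid,
# upsamples 4x, and projects to a single-channel heatmap of logits.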
class RangeModel(nn.Module):
    def __init__(self):
        super(RangeModel, self).__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        # Fixed location-feature grid; assigned after the model is constructed (see below).
        self.x = None

    def forward(self, text):
        x = self.cross_attn(self.x, text)
        x = rearrange(x, 'b (h w) d -> b d h w', h=225)
        x = self.upsample(x)
        x = self.out(x)
        return x


# Load the demo checkpoint on CPU and attach the precomputed location-embedding grid.
model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).unsqueeze(0).float()
model.eval()

# Text embeddings for each taxonomic level, stored as {name: embedding} dicts inside
# 0-d object arrays (hence the [()] indexing below).
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
# Genus and family embeddings are not loaded in this demo, so those levels are left out
# of the taxonomy dropdown below.
# genus = np.load("genus_70b.npy")
# family = np.load("family_70b.npy")
species_list = list(species[()].keys())
class_list = list(clas[()].keys())
order_list = list(order[()].keys())
# genus_list = list(genus[()].keys())
# family_list = list(family[()].keys())
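

# The three callbacks below drive the UI: update_fn repopulates the "Name" dropdown when the
# taxonomic level changes, text_fn reports the mean predicted score for a selection, and
# pred_fn renders the full predicted heatmap.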
def update_fn(val):
    if val == "Class":
        return gr.Dropdown(label="Name", choices=class_list, interactive=True)
    elif val == "Order":
        return gr.Dropdown(label="Name", choices=order_list, interactive=True)
    elif val == "Family":
        return gr.Dropdown(label="Name", choices=family_list, interactive=True)
    elif val == "Genus":
        return gr.Dropdown(label="Name", choices=genus_list, interactive=True)
    elif val == "Species":
        return gr.Dropdown(label="Name", choices=species_list, interactive=True)


def text_fn(taxon, name):
    if taxon == "Class":
        text_embeds = clas[()][name]
    elif taxon == "Order":
        text_embeds = order[()][name]
    elif taxon == "Family":
        text_embeds = family[()][name]
    elif taxon == "Genus":
        text_embeds = genus[()][name]
    elif taxon == "Species":
        text_embeds = species[()][name]
    text_embeds = torch.tensor(text_embeds)
    preds = model(text_embeds).sigmoid().detach().numpy()
    return taxon + ": " + name + ": " + str(np.mean(preds))


def pred_fn(taxon, name):
    if taxon == "Class":
        text_embeds = clas[()][name]
    elif taxon == "Order":
        text_embeds = order[()][name]
    elif taxon == "Family":
        text_embeds = family[()][name]
    elif taxon == "Genus":
        text_embeds = genus[()][name]
    elif taxon == "Species":
        text_embeds = species[()][name]
    text_embeds = torch.tensor(text_embeds)
    # Squeeze the singleton batch/channel dims so gr.Image receives a 2-D (H, W) array.
    preds = model(text_embeds).sigmoid().squeeze().detach().numpy()
    return gr.Image(preds, label="Predicted Heatmap", visible=True)
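

# UI: pick a taxonomic level and a name; "Check" writes the mean predicted score to the
# textbox, "Run Model" reveals the predicted heatmap.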
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Hierarchical Species Distribution Model!
        This model predicts the distribution of species based on geographic, environmental, and natural language features.
        """)
    with gr.Row():
        # Family and Genus are omitted because their embedding files are not loaded above.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        text = gr.Textbox(label="Text", visible=True, interactive=True)
    inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Check")
        submit_button = gr.Button("Run Model")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=False)

    check_button.click(text_fn, inputs=[inp, out], outputs=[text])
    submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])

demo.launch()