File size: 4,933 Bytes
a83d785 59e21dc 44c57c9 2bebf28 44c57c9 e1eafea dfcabe0 c7a6586 59e21dc 806fa12 59e21dc a83d785 43b8920 59e21dc a83d785 9ae1458 31833b2 e3de1e6 9ae1458 fa28aab 44c57c9 fa28aab 9ea5d8f 89a25ce 4d5623e 59e21dc 4d5623e 2c3470d 9ae1458 2c3470d 9ae1458 2c3470d fa28aab e3de1e6 fa28aab e3de1e6 2c3470d afb2a6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
class Attn(nn.Module):
    """Multi-head cross-attention from a feature map to text-token embeddings.

    Queries are taken from non-overlapping 4x4 patches of the feature map
    (a Conv2d with kernel 4 and stride 4); keys and values are a single linear
    projection of the text embedding, split in half.

    Args:
        dim: number of channels of the feature map (and of the output).
        dim_text: size of the last dimension of the text embedding.
        heads: number of attention heads.
        dim_head: width of each head.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # Creation order (to_q, to_kv, to_out) is kept so parameter order and
        # state_dict keys match the saved checkpoint.
        self.to_q = nn.Conv2d(dim, inner_dim, 4, bias=False, stride=4)
        self.to_kv = nn.Linear(dim_text, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, text):
        """Attend from patches of ``x`` to ``text``.

        Returns a tensor laid out as (batch, patches, dim), where patches is
        the flattened grid produced by ``to_q``.
        """
        # Unpacking also enforces a 4-D (batch, channels, height, width) input.
        _batch, _channels, _height, _width = x.shape

        keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_kv(text).chunk(2, dim=-1)
        )
        queries = rearrange(self.to_q(x), 'b (h c) x y -> b h (x y) c', h=self.heads)

        scores = (queries @ keys.transpose(-1, -2)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        merged = rearrange(weights @ values, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class RangeModel(nn.Module):
    """Map a text embedding to a single-channel map over the query grid.

    The query input of the cross-attention is ``self.x``, which must be
    assigned (e.g. a positional-embedding tensor with 128 channels) after
    construction. It is a plain attribute, not a parameter or buffer, so it
    is not part of the state dict.
    """

    def __init__(self):
        super().__init__()
        # 128-channel query map, 8192-dim text embedding.
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        self.x = None

    def forward(self, text):
        """Return a (batch, 1, H, W) map of logits for ``text``.

        Raises:
            RuntimeError: if ``self.x`` has not been assigned yet.
        """
        if self.x is None:
            raise RuntimeError(
                "RangeModel.x must be assigned a query tensor before calling forward()"
            )
        x = self.cross_attn(self.x, text)
        # Attn flattens its query grid. Its to_q conv has kernel == stride and
        # no padding, so the grid height is floor(input height / stride).
        # This replaces a hard-coded 225, which only fit a 900-903 px input.
        grid_h = self.x.shape[-2] // self.cross_attn.to_q.stride[0]
        x = rearrange(x, 'b (h w) d -> b d h w', h=grid_h)
        x = self.upsample(x)
        return self.out(x)
# Build the model on CPU and restore the trained weights.
model = RangeModel()
model.load_state_dict(
    torch.load("model/demo_model.pt", map_location="cpu")
)
# NOTE(review): allow_pickle=True unpickles file contents; only load trusted data.
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
# The cross-attention queries are computed from this tensor (see RangeModel.x).
model.x = torch.tensor(pos_embed).float()
model.eval()

# Text-embedding tables, one file per taxonomic level. Each file is indexed
# with [()] and then used as a mapping from name to embedding (presumably a
# 0-d object array wrapping a dict -- TODO confirm).
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
# NOTE: genus_70b.npy and family_70b.npy are not loaded, so `genus`, `family`,
# `genus_list` and `family_list` are not defined in this module.

# Dropdown choices: the names available at each loaded level.
species_list, class_list, order_list = (
    list(table[()].keys()) for table in (species, clas, order)
)
def update_fn(val):
    """Return a "Name" dropdown listing the names for taxonomic level *val*.

    Only the Class, Order and Species tables are loaded at module level. The
    Family and Genus loaders are commented out, so referencing
    ``family_list``/``genus_list`` used to raise ``NameError``.

    Args:
        val: the selected taxonomic level, or ``None`` when the selection is cleared.

    Returns:
        A ``gr.Dropdown`` update. It is empty when nothing is selected.

    Raises:
        gr.Error: if *val* is a level whose embeddings are not loaded.
    """
    choices_by_level = {
        "Class": class_list,
        "Order": order_list,
        "Species": species_list,
    }
    if val is None:
        # Cleared selection: show an empty list instead of returning None.
        return gr.Dropdown(label="Name", choices=[], interactive=True)
    if val not in choices_by_level:
        raise gr.Error(f"No names are available for taxonomic level {val!r} in this demo.")
    return gr.Dropdown(label="Name", choices=choices_by_level[val], interactive=True)
def text_fn(taxon, name):
    """Run the range model for one taxon name and summarise the prediction.

    Args:
        taxon: taxonomic level selected in the UI (e.g. "Class", "Order", "Species").
        name: a key of that level's embedding table.

    Returns:
        A ``"<taxon>: <name>: <mean probability>"`` summary string, and the
        predicted probability map as a 2-D array suitable for ``gr.Image``.

    Raises:
        gr.Error: if the level has no loaded embeddings, or *name* is missing
            or not in the table.
    """
    # Only these tables are loaded at module level. The Family/Genus loaders
    # are commented out, so those levels used to raise NameError here.
    tables = {"Class": clas, "Order": order, "Species": species}
    if taxon not in tables:
        raise gr.Error(f"No embeddings are loaded for taxonomic level {taxon!r}.")
    embeddings = tables[taxon][()]
    if name is None or name not in embeddings:
        raise gr.Error(f"Select a valid {taxon} name (got {name!r}).")

    # Cast to float32 to match the model weights. The stored dtype is not
    # visible here; .float() is a no-op when it is already float32.
    text_embeds = torch.tensor(embeddings[name]).float()
    with torch.no_grad():  # inference only: no autograd graph needed
        probs = model(text_embeds).sigmoid().numpy()
    # RangeModel ends in a 1-channel Conv2d, so probs is (batch, 1, H, W).
    # gr.Image needs a 2-D (H, W) array; the old unsqueeze calls made it 6-D.
    heatmap = probs[0, 0]
    return f"{taxon}: {name}: {np.mean(probs)!s}", heatmap
def pred_fn(taxon, name):
    """Run the range model for one taxon name and return the heatmap component.

    Args:
        taxon: taxonomic level selected in the UI (e.g. "Class", "Order", "Species").
        name: a key of that level's embedding table.

    Returns:
        A ``gr.Image`` holding the predicted probability map as a 2-D array.

    Raises:
        gr.Error: if the level has no loaded embeddings, or *name* is missing
            or not in the table.
    """
    # Only these tables are loaded at module level. The Family/Genus loaders
    # are commented out, so those levels used to raise NameError here.
    tables = {"Class": clas, "Order": order, "Species": species}
    if taxon not in tables:
        raise gr.Error(f"No embeddings are loaded for taxonomic level {taxon!r}.")
    embeddings = tables[taxon][()]
    if name is None or name not in embeddings:
        raise gr.Error(f"Select a valid {taxon} name (got {name!r}).")

    # Cast to float32 to match the model weights (no-op if already float32).
    text_embeds = torch.tensor(embeddings[name]).float()
    with torch.no_grad():  # inference only: no autograd graph needed
        probs = model(text_embeds).sigmoid().numpy()
    # probs is (batch, 1, H, W); gr.Image needs a 2-D (H, W) array.
    heatmap = probs[0, 0]
    return gr.Image(heatmap, label="Predicted Heatmap", visible=True)
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Hierarchical Species Distribution Model!
This model predicts the distribution of species based on geographic, environmental, and natural language features.
"""
    )
    with gr.Row():
        # Only offer levels whose embedding tables are loaded at module level.
        # Family/Genus loaders are commented out, so selecting them would crash
        # the handlers. Re-add them here once their .npy files are loaded.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        text = gr.Textbox(label="Text", visible=True, interactive=True)
        # Repopulate the Name dropdown whenever the level changes.
        inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Check")
        # TODO: "Run Model" has no click handler yet; pred_fn exists but is not hooked up.
        submit_button = gr.Button("Run Model")
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    # "Check" runs the model and shows both the summary text and the heatmap.
    check_button.click(text_fn, inputs=[inp, out], outputs=[text, pred])
demo.launch()