File size: 5,015 Bytes
a83d785 59e21dc 44c57c9 41180d8 34da2e5 b5844ee 44c57c9 435a5b8 44c57c9 b5844ee 2bebf28 44c57c9 e1eafea dfcabe0 c7a6586 59e21dc 806fa12 59e21dc 3c43ea2 59e21dc c3d7938 a83d785 43b8920 59e21dc a83d785 9ae1458 435a5b8 31833b2 b5844ee 435a5b8 9d3003c 41180d8 31833b2 3f0099e 9ae1458 435a5b8 34da2e5 2363ad4 9d3003c fa28aab 34da2e5 435a5b8 4d54978 fa28aab 9ea5d8f 89a25ce 4d5623e 59e21dc 4d5623e 2c3470d 3f0099e 2c3470d fa28aab e3de1e6 fa28aab 3f0099e 4d54978 2c3470d afb2a6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
import rasterio
from rasterio.enums import Resampling
# Most recent continuous prediction map. Written by text_fn, re-used by
# thresh_fn so the threshold slider does not have to re-run the model.
pred_global = None

# Boolean mask of shape (900, 1800): True where the raster cell equals 1
# (presumably "land", going by the file name). The raster is resampled to
# the requested grid with nearest-neighbour so the values stay categorical.
# The dataset is opened in a context manager so the file handle is released
# as soon as the array has been read.
with rasterio.open('data/LAND_MASK.tif') as _land_mask_src:
    land_mask = (
        _land_mask_src.read(
            out_shape=(1, 900, 1800),
            resampling=Resampling.nearest,
        )
        == 1
    ).squeeze(0)
class Attn(nn.Module):
    """Multi-head cross-attention from an image-like feature map to a text sequence.

    Queries are produced from ``x`` by a 4x4 convolution with stride 4, i.e.
    one query token per non-overlapping 4x4 patch. Keys and values are linear
    projections of ``text``.

    Args:
        dim: Channel count of the feature map; also the output feature size.
        dim_text: Feature size of each text token.
        heads: Number of attention heads.
        dim_head: Feature size per head.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        # Submodules are created in this order on purpose: it fixes both the
        # state_dict key order and the order in which init RNG is consumed.
        self.to_q = nn.Conv2d(dim, inner_dim, kernel_size=4, stride=4, bias=False)
        self.to_kv = nn.Linear(dim_text, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, text):
        """Attend from every 4x4 patch of ``x`` to the tokens of ``text``.

        Args:
            x: Tensor of shape (batch, dim, height, width).
            text: Tensor of shape (batch, tokens, dim_text).

        Returns:
            Tensor of shape (batch, patches, dim), where ``patches`` is the
            number of spatial positions produced by the query convolution.
        """
        # Unpacking doubles as a rank check: it raises unless x is 4-D.
        _batch, _channels, _height, _width = x.shape

        keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_kv(text).chunk(2, dim=-1)
        )
        queries = rearrange(
            self.to_q(x), 'b (h c) x y -> b h (x y) c', h=self.heads
        )

        # Scaled dot-product attention, normalised over the text tokens.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale
        weights = scores.softmax(dim=-1)
        attended = torch.matmul(weights, values)

        # Merge the heads back into a single feature axis before projecting.
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
class RangeModel(nn.Module):
    """Range-prediction head: cross-attention, 4x upsampling, 1x1 convolution.

    The query feature map is not a forward argument; it is read from
    ``self.x``, which starts as ``None`` and must be assigned by the caller
    before the model is run.
    """

    def __init__(self):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, kernel_size=1, bias=False)
        # Plain attribute (not a registered buffer or parameter), so it is
        # not part of the state_dict.
        self.x = None

    def forward(self, text):
        """Return a single-channel map of logits for the given text embedding.

        Args:
            text: Text embedding passed to the cross-attention as keys/values.

        Returns:
            Tensor of shape (batch, 1, H, W): the attention tokens reshaped
            to a grid with 225 rows, upsampled by 4, and reduced to one
            channel.
        """
        tokens = self.cross_attn(self.x, text)
        # The token axis is unflattened into a grid with a fixed 225 rows;
        # the column count follows from the number of tokens.
        grid = rearrange(tokens, 'b (h w) d -> b d h w', h=225)
        return self.out(self.upsample(grid))
def _load_pickled_npy(path):
    """Load ``path`` with ``np.load`` with pickled object payloads enabled."""
    # NOTE(review): allow_pickle=True (like torch.load below) unpickles the
    # file contents; this assumes the files are trusted assets of the app.
    return np.load(path, allow_pickle=True)


# Model: restore the trained weights on CPU, attach the position embeddings
# that RangeModel.forward feeds to the attention as queries, then switch to
# eval mode.
model = RangeModel()
model.load_state_dict(
    torch.load("model/demo_model.pt", map_location=torch.device('cpu'))
)
pos_embed = _load_pickled_npy("data/pos_embeds_model.npy")
model.x = torch.tensor(pos_embed).float()
model.eval()

# Per-rank embedding tables. Indexing a table with `[()]` yields the stored
# mapping from taxon name to embedding (the code below calls `.keys()` on it).
species = _load_pickled_npy("data/species_70b.npy")
clas = _load_pickled_npy("data/class_70b.npy")
order = _load_pickled_npy("data/order_70b.npy")
family = _load_pickled_npy("data/family_70b.npy")
# NOTE: genus embeddings are not loaded here, so neither `genus` nor
# `genus_list` exists at runtime.

# Names offered in the UI dropdown for each rank.
species_list = [*species[()].keys()]
class_list = [*clas[()].keys()]
order_list = [*order[()].keys()]
family_list = [*family[()].keys()]
def update_fn(val):
    """Rebuild the "Name" dropdown for the selected taxonomic rank.

    Args:
        val: Rank chosen in the "Taxonomic Hierarchy" dropdown.

    Returns:
        A ``gr.Dropdown`` update whose choices are the names known for that
        rank. Ranks without a loaded name list get an empty choice list; this
        includes "Genus", because no genus list is defined in this file and
        referencing it used to raise ``NameError``.
    """
    names_by_rank = {
        "Class": class_list,
        "Order": order_list,
        "Family": family_list,
        "Species": species_list,
    }
    return gr.Dropdown(
        label="Name",
        choices=names_by_rank.get(val, []),
        interactive=True,
    )
def text_fn(taxon, name):
    """Predict the range heatmap for one taxon and return it as an RGB image.

    Args:
        taxon: Taxonomic rank selected in the UI ("Class", "Order", "Family"
            or "Species").
        name: Key of the taxon inside the embedding table of that rank.

    Returns:
        Float array of shape (H, W, 3): the masked prediction mapped through
        the 'plasma' colormap with the alpha channel removed.

    Raises:
        gr.Error: If no embedding table is loaded for ``taxon`` (this covers
            an empty selection and "Genus", whose embeddings are not loaded
            in this file) or if ``name`` is not a key of that table.

    Side effects:
        Stores the raw masked prediction in the module-level ``pred_global``
        so that ``thresh_fn`` can threshold it without re-running the model.
    """
    global pred_global

    tables_by_rank = {
        "Class": clas,
        "Order": order,
        "Family": family,
        "Species": species,
    }
    table = tables_by_rank.get(taxon)
    if table is None:
        raise gr.Error(
            f"No embeddings are available for taxonomic rank {taxon!r}."
        )

    embeddings = table[()]
    if name not in embeddings:
        raise gr.Error(
            f"Unknown name {name!r} for rank {taxon!r}; "
            "please pick a name from the dropdown."
        )

    text_embeds = torch.tensor(embeddings[name])
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(text_embeds)
    # Sigmoid turns logits into values in [0, 1]; the mask zeroes every cell
    # where land_mask is False.
    preds = logits.sigmoid().squeeze(0).squeeze(0).numpy() * land_mask
    pred_global = preds

    rgba_img = plt.get_cmap('plasma')(preds)
    # Drop the alpha channel (index 3 on the last axis).
    return np.delete(rgba_img, 3, axis=2)
def thresh_fn(val):
    """Binarise the most recent prediction at ``val`` and return it as RGB.

    Args:
        val: Threshold. Cells whose prediction is >= ``val`` become 1, all
            other cells become 0.

    Returns:
        Float array of shape (H, W, 3) produced by the 'plasma' colormap with
        the alpha channel removed, or ``None`` when no prediction has been
        stored yet (``pred_global`` is still ``None``).
    """
    if pred_global is None:
        # The slider can be moved before the model has been run; without this
        # guard the comparison below would raise TypeError on None.
        return None

    # Build a new array instead of mutating the stored prediction. The float
    # dtype is kept on purpose: matplotlib colormaps interpret integer arrays
    # as direct lookup-table indices rather than values in [0, 1].
    binary = (pred_global >= val).astype(pred_global.dtype)

    rgba_img = plt.get_cmap('plasma')(binary)
    # Drop the alpha channel (index 3 on the last axis).
    return np.delete(rgba_img, 3, axis=2)
# UI layout and event wiring for the demo.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        # "Genus" is not offered: no genus embeddings are loaded in this
        # file, so selecting it could never produce a prediction.
        inp = gr.Dropdown(
            label="Taxonomic Hierarchy",
            choices=["Class", "Order", "Family", "Species"],
        )
        out = gr.Dropdown(label="Name", interactive=True)
    # Changing the rank repopulates the "Name" dropdown.
    inp.change(update_fn, inp, out)
    with gr.Row():
        check_button = gr.Button("Run Model")
    with gr.Row():
        # `value` is the Slider's initial-value parameter; the `default`
        # keyword used before is not part of the current Slider signature.
        slider = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.5,
            label="Confidence Threshold",
        )
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    # "Run Model" draws the continuous heatmap; moving the slider redraws the
    # last prediction as a binary map at the chosen threshold.
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)

demo.launch()