import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
from copy import deepcopy
import rasterio
from rasterio.enums import Resampling

# Cache of the most recent raw prediction map, reused by the threshold slider.
pred_global = None

# Binary land mask at the model's output resolution (900 x 1800), used to zero out ocean pixels.
land_mask = (
    rasterio.open('data/LAND_MASK.tif')
    .read(out_shape=(1, 900, 1800), resampling=Resampling.nearest)
    == 1
).squeeze(0)

class Attn(nn.Module):
    """Cross-attention from spatial query tokens to text key/value tokens."""
    def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # Strided 4x4 conv projects to the query space while downsampling the grid 4x.
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias = False, stride=4)
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        b, c, h, w = x.shape
        # Keys/values come from the text embedding, queries from the feature map.
        kv = self.to_kv(text).chunk(2, dim = -1)
        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), kv)
        q = self.to_q(x)
        q = rearrange(q, 'b (h c) x y -> b h (x y) c', h=self.heads)

        # Scaled dot-product attention over the text tokens.
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
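
# The block below is an optional, illustrative sanity check of the cross-attention
# shapes; it is not called by the app. The small dimensions are arbitrary stand-ins,
# not the ones used by the demo model.
def _attn_shape_check():
    attn = Attn(dim=32, dim_text=64, heads=2, dim_head=8)
    x = torch.randn(1, 32, 8, 16)   # (batch, dim, H, W); H and W must be divisible by 4
    text = torch.randn(1, 5, 64)    # (batch, n_text_tokens, dim_text)
    out = attn(x, text)
    # The strided conv reduces the spatial grid 4x, so the output has one token per
    # downsampled pixel: (batch, H/4 * W/4, dim).
    assert out.shape == (1, (8 // 4) * (16 // 4), 32)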

class RangeModel(nn.Module):
    def __init__(self):
        super(RangeModel, self).__init__()
        self.cross_attn = Attn(128, 8192)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        # Precomputed image-side feature map; assigned after the weights are loaded below.
        self.x = None

    def forward(self, text):
        # Attend from every spatial location to the text tokens, fold the token
        # sequence back into a 225-row grid, upsample 4x, and project to one channel.
        x = self.cross_attn(self.x, text)
        x = rearrange(x, 'b (h w) d -> b d h w', h=225)
        x = self.upsample(x)
        x = self.out(x)
        return x

model = RangeModel()
model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
# The image-side features are precomputed, so the demo only runs the text-conditioned head.
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()
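
# Optional end-to-end shape check (not called by the app, and memory-heavy). Given the
# hard-coded 225-row rearrange, the 4x upsample, and the (900, 1800) land mask, the
# loaded positional embedding is presumably a 900 x 1800 grid with 128 channels, and a
# text embedding of shape (1, n_tokens, 8192) should yield a single-channel
# (1, 1, 900, 1800) map. The token count below is an arbitrary placeholder.
def _model_shape_check(n_tokens=4):
    dummy_text = torch.randn(1, n_tokens, 8192)
    with torch.no_grad():
        assert model(dummy_text).shape == (1, 1) + land_mask.shape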

# Each *_70b.npy file holds precomputed text embeddings keyed by taxon name.
species = np.load("data/species_70b.npy", allow_pickle=True)
clas = np.load("data/class_70b.npy", allow_pickle=True)
order = np.load("data/order_70b.npy", allow_pickle=True)
family = np.load("data/family_70b.npy", allow_pickle=True)
# Genus embeddings are not loaded, so the Genus level is disabled in the UI below.
#genus = np.load("genus_70b.npy")

species_list = list(species[()].keys())
class_list = list(clas[()].keys())
order_list = list(order[()].keys())
family_list = list(family[()].keys())
#genus_list = list(genus[()].keys())
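
# Illustrative lookup helper (not wired into the app): the *_70b.npy files appear to
# be 0-d object arrays wrapping {taxon name: text embedding} dicts, hence the [()]
# indexing above. Each embedding is assumed to already carry a batch axis with a
# trailing dimension of 8192, since text_fn below feeds it to the model unchanged.
def _peek_embedding(level_dict, name):
    emb = torch.tensor(level_dict[()][name])
    return emb.shape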

def update_fn(val):
    # Populate the name dropdown with the taxa available at the selected level.
    if val=="Class":
        return gr.Dropdown(label="Name", choices=class_list, interactive=True)
    elif val=="Order":
        return gr.Dropdown(label="Name", choices=order_list, interactive=True)
    elif val=="Family":
        return gr.Dropdown(label="Name", choices=family_list, interactive=True)
    elif val=="Species":
        return gr.Dropdown(label="Name", choices=species_list, interactive=True)

def text_fn(taxon, name):
    # Look up the text embedding for the selected taxon, run the model, and return
    # an RGB heatmap of the predicted distribution masked to land pixels.
    global pred_global
    if taxon=="Class":
        text_embeds = clas[()][name]
    elif taxon=="Order":
        text_embeds = order[()][name]
    elif taxon=="Family":
        text_embeds = family[()][name]
    elif taxon=="Species":
        text_embeds = species[()][name]

    text_embeds = torch.tensor(text_embeds)
    preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy() * land_mask
    pred_global = preds

    # Colour-map the probabilities and drop the alpha channel for display.
    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    rgb_img = np.delete(rgba_img, 3, 2)
    return rgb_img

def thresh_fn(val):
    # Binarise the cached prediction map at the chosen confidence threshold.
    global pred_global
    if pred_global is None:
        # The slider can fire before a prediction has been made; nothing to threshold yet.
        return None
    preds = deepcopy(pred_global)
    preds[preds<val] = 0
    preds[preds>=val] = 1

    cmap = plt.get_cmap('plasma')
    rgba_img = cmap(preds)
    rgb_img = np.delete(rgba_img, 3, 2)
    return rgb_img

with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        # "Genus" is omitted because genus embeddings are not bundled with the demo.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Family", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        inp.change(update_fn, inp, out)
    
    with gr.Row():
        check_button = gr.Button("Run Model")
    
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Confidence Threshold")
    
    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)
    
    check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
    slider.change(thresh_fn, slider, outputs=pred)
    
demo.launch()