File size: 4,933 Bytes
a83d785
59e21dc
44c57c9
 
 
 
 
2bebf28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44c57c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1eafea
dfcabe0
c7a6586
59e21dc
806fa12
 
 
59e21dc
 
 
 
 
 
 
 
a83d785
43b8920
59e21dc
 
 
 
 
 
 
 
 
 
a83d785
9ae1458
31833b2
 
 
 
 
 
 
 
 
 
 
 
 
 
e3de1e6
9ae1458
fa28aab
 
 
 
 
 
 
 
 
 
 
44c57c9
 
 
 
fa28aab
 
9ea5d8f
89a25ce
4d5623e
59e21dc
 
4d5623e
2c3470d
 
 
9ae1458
2c3470d
 
 
9ae1458
2c3470d
 
fa28aab
e3de1e6
fa28aab
e3de1e6
 
2c3470d
afb2a6a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class Attn(nn.Module):
    """Multi-head cross-attention from patches of a feature map to text tokens.

    Queries come from non-overlapping ``patch_size`` x ``patch_size`` patches of
    the spatial input ``x`` (a strided convolution). Keys and values come from
    a linear projection of ``text``. The result is one ``dim``-sized vector per
    query patch, with patches flattened row-major (patch-row, then patch-col).

    Args:
        dim: channel count of ``x`` and size of each output vector.
        dim_text: size of the last axis of ``text``.
        heads: number of attention heads.
        dim_head: per-head width; q/k/v all use ``heads * dim_head`` features.
        patch_size: kernel size and stride of the query convolution. The
            default of 4 is the value that was previously hard-coded.
    """

    def __init__(self, dim, dim_text, heads = 16, dim_head = 64, patch_size = 4):
        super().__init__()
        # Standard 1/sqrt(d) scaling of the dot-product logits.
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # kernel == stride: each non-overlapping patch of x becomes one query token.
        self.to_q = nn.Conv2d(dim, hidden_dim, patch_size, bias = False, stride=patch_size)
        # A single projection produces keys and values; split in half in forward().
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        """Attend from the patches of ``x`` to the tokens of ``text``.

        Args:
            x: 4-D tensor ``(b, dim, H, W)`` (the query rearrange needs 4 axes).
            text: 3-D tensor ``(b, n, dim_text)`` of ``n`` text tokens. The
                batch axes of ``x`` and ``text`` must be equal or broadcastable
                under ``torch.matmul``.

        Returns:
            Tensor ``(b, H' * W', dim)`` where ``H' x W'`` is the output grid of
            the query convolution.
        """
        kv = self.to_kv(text).chunk(2, dim = -1)
        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), kv)
        q = rearrange(self.to_q(x), 'b (h c) x y -> b h (x y) c', h=self.heads)

        # Softmax over the text-token axis: each patch distributes its
        # attention across the text tokens.
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back into one feature vector per query patch.
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class RangeModel(nn.Module):
    """Predict a single-channel map by cross-attending from ``self.x`` to text.

    ``self.x`` holds the spatial query input and must be assigned by the
    caller after construction. It is a plain attribute, not a parameter or
    registered buffer, so it is absent from ``state_dict()`` and is not moved
    by ``.to()`` / ``.cuda()``.
    """

    def __init__(self):
        super().__init__()
        self.cross_attn = Attn(128, 8192)
        # Scale back up by 4 to undo the query conv's downsampling
        # (NOTE(review): assumes Attn's default patch size of 4).
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
        self.out = nn.Conv2d(128, 1, 1, bias=False)
        self.x = None  # spatial query input; set by the caller before forward()

    def forward(self, text):
        """Return a ``(b, 1, H, W)``-style map of logits for the given text tokens.

        Args:
            text: text-token tensor passed straight to ``Attn.forward``.

        Raises:
            RuntimeError: if ``self.x`` has not been assigned yet.
        """
        if self.x is None:
            raise RuntimeError(
                "RangeModel.x must be set to the spatial query input before calling forward()"
            )
        x = self.cross_attn(self.x, text)
        # Height of the query grid, computed exactly as Conv2d does for
        # padding=0, dilation=1, instead of assuming a fixed 225.
        conv = self.cross_attn.to_q
        h = (self.x.shape[-2] - conv.kernel_size[0]) // conv.stride[0] + 1
        x = rearrange(x, 'b (h w) d -> b d h w', h=h)
        x = self.upsample(x)
        x = self.out(x)
        return x

# Build the model, restore its trained weights onto the CPU, and attach the
# positional embeddings that RangeModel.forward feeds in as the query input.
model = RangeModel()
model.load_state_dict(
    torch.load(
        "model/demo_model.pt",
        map_location=torch.device('cpu'),
    )
)
pos_embed = np.load("data/pos_embeds_model.npy", allow_pickle=True)
model.x = torch.tensor(pos_embed).float()
model.eval()

# Per-level text-embedding tables. Each load yields a 0-d object array that is
# unwrapped with [()] to reach a name -> embedding mapping (hence allow_pickle).
species, clas, order = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in (
        "data/species_70b.npy",
        "data/class_70b.npy",
        "data/order_70b.npy",
    )
)

# Family and Genus are not loaded, so code that refers to family / genus /
# family_list / genus_list cannot serve those levels until restored, e.g.:
#   genus = np.load("genus_70b.npy")
#   family = np.load("family_70b.npy")
#   genus_list = list(genus[()].keys())
#   family_list = list(family[()].keys())
# NOTE(review): unlike the loads above, those paths lack the "data/" prefix
# and allow_pickle=True -- confirm before re-enabling.

# Dropdown choices for each loaded level: the names keyed in its table.
species_list, class_list, order_list = (
    list(table[()].keys()) for table in (species, clas, order)
)

def update_fn(val):
    """Return a refreshed "Name" dropdown listing the names for level *val*.

    Args:
        val: selected taxonomic level ("Class", "Order", "Family", "Genus" or
            "Species").

    Returns:
        A ``gr.Dropdown`` with the level's names as choices, or ``None`` for an
        unrecognised / empty *val* (unchanged from the original behaviour).

    Raises:
        gr.Error: if the name list for *val* is not defined in this module
            (``family_list`` / ``genus_list`` are currently not loaded). This
            replaces an opaque ``NameError``.
    """
    try:
        if val == "Class":
            choices = class_list
        elif val == "Order":
            choices = order_list
        elif val == "Family":
            choices = family_list
        elif val == "Genus":
            choices = genus_list
        elif val == "Species":
            choices = species_list
        else:
            return None
    except NameError:
        # Only reachable for levels whose name list was never loaded.
        raise gr.Error(f"No {val} names are loaded in this demo.") from None
    return gr.Dropdown(label="Name", choices=choices, interactive=True)

def _get_text_embeds(taxon, name):
    """Look up *name*'s text embedding at level *taxon* as a float32 tensor.

    Raises:
        gr.Error: when no level/name is selected, the level is unknown, the
            level's table is not loaded (Family/Genus are currently not), or
            *name* is not in that level's table.
    """
    if not taxon or not name:
        raise gr.Error("Select a taxonomic level and a name first.")
    try:
        if taxon == "Class":
            table = clas
        elif taxon == "Order":
            table = order
        elif taxon == "Family":
            table = family
        elif taxon == "Genus":
            table = genus
        elif taxon == "Species":
            table = species
        else:
            raise gr.Error(f"Unknown taxonomic level: {taxon}")
    except NameError:
        # The table for this level was never loaded at module level.
        raise gr.Error(f"{taxon} embeddings are not loaded in this demo.") from None
    # The loaded table is a 0-d object array; [()] unwraps the mapping inside.
    lookup = table[()]
    if name not in lookup:
        # e.g. a name left selected after switching to another level.
        raise gr.Error(f"{name!r} was not found at the {taxon} level.")
    # Cast to float32 like model.x so the Linear layers see matching dtypes.
    return torch.tensor(lookup[name]).float()


def _predict(text_embeds):
    """Run the model on *text_embeds* and return sigmoid probabilities as NumPy."""
    with torch.no_grad():
        probs = model(text_embeds).sigmoid()
    # NOTE(review): the model returns a 4-D (b, 1, H, W) tensor, so the two
    # leading unsqueezes give a 6-D array; confirm gr.Image accepts this rank
    # (a squeeze may have been intended). Kept as-is to preserve behaviour.
    return probs.unsqueeze(0).unsqueeze(0).numpy()


def text_fn(taxon, name):
    """Handle the "Check" button: summarise and return the predicted heatmap.

    Returns:
        ``("<taxon>: <name>: <mean probability>", heatmap_array)``.

    Raises:
        gr.Error: for an invalid or missing selection (see ``_get_text_embeds``).
    """
    preds = _predict(_get_text_embeds(taxon, name))
    return f"{taxon}: {name}: {np.mean(preds)!s}", preds


def pred_fn(taxon, name):
    """Handle the "Run Model" button: return an image component with the heatmap.

    Raises:
        gr.Error: for an invalid or missing selection (see ``_get_text_embeds``).
    """
    preds = _predict(_get_text_embeds(taxon, name))
    return gr.Image(preds, label="Predicted Heatmap", visible=True)


with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hierarchical Species Distribution Model!
    This model predicts the distribution of species based on geographic, environmental, and natural language features.
    """)
    with gr.Row():
        # Only levels whose embeddings are loaded in this file are offered.
        # Family and Genus are omitted because their tables/name lists are not
        # loaded; re-add them once those loads are restored.
        inp = gr.Dropdown(label="Taxonomic Hierarchy", choices=["Class", "Order", "Species"])
        out = gr.Dropdown(label="Name", interactive=True)
        text = gr.Textbox(label="Text", visible=True, interactive=True)
        # Repopulate the Name choices whenever the level changes.
        inp.change(update_fn, inp, out)

    with gr.Row():
        check_button = gr.Button("Check")
        submit_button = gr.Button("Run Model")

    with gr.Row():
        pred = gr.Image(label="Predicted Heatmap", visible=True)

    # "Check" shows a text summary plus the heatmap. "Run Model" was
    # previously unwired; it now renders the heatmap via pred_fn.
    check_button.click(text_fn, inputs=[inp, out], outputs=[text, pred])
    submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])

demo.launch()