Vishu26 commited on
Commit
44c57c9
·
1 Parent(s): fa28aab
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -1,5 +1,27 @@
1
  import gradio as gr
2
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  species = np.load("data/species_70b.npy", allow_pickle=True)
5
  clas = np.load("data/class_70b.npy", allow_pickle=True)
@@ -13,7 +35,7 @@ order_list = list(order[()].keys())
13
  #genus_list = list(genus[()].keys())
14
  #family_list = list(family[()].keys())
15
 
16
- pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
17
 
18
  def update_fn(val):
19
  if val=="Class":
@@ -38,7 +60,10 @@ def pred_fn(taxon, name):
38
  text_embeds = genus[()][name]
39
  elif taxon=="Species":
40
  text_embeds = species[()][name]
41
-
 
 
 
42
 
43
 
44
  with gr.Blocks() as demo:
@@ -60,5 +85,4 @@ with gr.Blocks() as demo:
60
 
61
  submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
62
 
63
-
64
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+
7
+
8
+ class RangeModel(nn.Module):
9
+ def __init__(self):
10
+ super(RangeModel, self).__init__()
11
+ self.cross_attn = Attn(128, 8192)
12
+ self.upsample = nn.Upsample(scale_factor=4, mode='bilinear')
13
+ self.out = nn.Conv2d(128, 1, 1, bias=False)
14
+ self.x = None
15
+
16
+ def forward(self, text):
17
+ x = self.cross_attn(self.x, text)
18
+ x = rearrange(x, 'b (h w) d -> b d h w', h=225)
19
+ x = self.upsample(x)
20
+ x = self.out(x)
21
+ return x
22
+
23
+ model = RangeModel()
24
+ model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
25
 
26
  species = np.load("data/species_70b.npy", allow_pickle=True)
27
  clas = np.load("data/class_70b.npy", allow_pickle=True)
 
35
  #genus_list = list(genus[()].keys())
36
  #family_list = list(family[()].keys())
37
 
38
+ #pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
39
 
40
  def update_fn(val):
41
  if val=="Class":
 
60
  text_embeds = genus[()][name]
61
  elif taxon=="Species":
62
  text_embeds = species[()][name]
63
+
64
+ text_embeds = torch.tensor(text_embeds)
65
+ preds = model(text_embeds).sigmoid().unsqueeze(0).unsqueeze(0).detach().numpy()
66
+ return gr.Image(preds, label="Predicted Heatmap", visible=True)
67
 
68
 
69
  with gr.Blocks() as demo:
 
85
 
86
  submit_button.click(pred_fn, inputs=[inp, out], outputs=[pred])
87
 
 
88
  demo.launch()