Vishu26 commited on
Commit
41180d8
·
1 Parent(s): cb4da0e
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +6 -1
  3. data/family_70b.npy +3 -0
  4. requirements.txt +2 -1
.gitattributes CHANGED
@@ -38,3 +38,4 @@ data/order_70b.npy filter=lfs diff=lfs merge=lfs -text
38
  data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
39
  data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
40
  model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
 
 
38
  data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
39
  data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
40
  model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
41
+ data/family_70b.npy filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import torch
4
  import torch.nn as nn
5
  from einops import rearrange
 
6
 
7
 
8
  class Attn(nn.Module):
@@ -91,8 +92,12 @@ def text_fn(taxon, name):
91
 
92
  text_embeds = torch.tensor(text_embeds)
93
  preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
 
 
 
 
94
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
95
- return taxon + ": " + name + ": " + str(np.mean(preds)), preds
96
 
97
  def pred_fn(taxon, name):
98
  if taxon=="Class":
 
3
  import torch
4
  import torch.nn as nn
5
  from einops import rearrange
6
+ import matplotlib.pyplot as plt
7
 
8
 
9
  class Attn(nn.Module):
 
92
 
93
  text_embeds = torch.tensor(text_embeds)
94
  preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
95
+ cmap = plt.get_cmap('Greens')
96
+
97
+ rgba_img = cmap(preds)
98
+ rgb_img = np.delete(rgba_img, 3, 2)
99
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
100
+ return taxon + ": " + name + ": " + str(np.mean(preds)), rgb_img
101
 
102
  def pred_fn(taxon, name):
103
  if taxon=="Class":
data/family_70b.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:306af2d796edbe2b41c352bc23196ac7650f9f977fb8afbe777dc1e68c6e9b8b
3
+ size 140296975
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  numpy==1.23.4
2
  torch==2.0.1
3
  rasterio==1.3.8
4
- einops==0.6.1
 
 
1
  numpy==1.23.4
2
  torch==2.0.1
3
  rasterio==1.3.8
4
+ einops==0.6.1
5
+ matplotlib