Vishu26 commited on
Commit
b5844ee
·
1 Parent(s): 4d54978
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +5 -1
  3. data/LAND_MASK.tif +3 -0
  4. requirements.txt +2 -1
.gitattributes CHANGED
@@ -39,3 +39,4 @@ data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
39
  data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
40
  model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
41
  data/family_70b.npy filter=lfs diff=lfs merge=lfs -text
 
 
39
  data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
40
  model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
41
  data/family_70b.npy filter=lfs diff=lfs merge=lfs -text
42
+ data/LAND_MASK.tif filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,9 +5,13 @@ import torch.nn as nn
5
  from einops import rearrange
6
  import matplotlib.pyplot as plt
7
  from copy import deepcopy
 
 
8
 
9
  pred_global = None
10
 
 
 
11
  class Attn(nn.Module):
12
  def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
13
  super().__init__()
@@ -94,7 +98,7 @@ def text_fn(taxon, name):
94
  text_embeds = species[()][name]
95
 
96
  text_embeds = torch.tensor(text_embeds)
97
- preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
98
  pred_global = preds
99
  cmap = plt.get_cmap('Greens')
100
 
 
5
  from einops import rearrange
6
  import matplotlib.pyplot as plt
7
  from copy import deepcopy
8
+ import rasterio
9
+ from rasterio.enums import Resampling
10
 
11
  pred_global = None
12
 
13
+ land_mask = (rasterio.open('data/LAND_MASK.tif').read(out_shape=(1, 900, 1800), resampling=Resampling.nearest) == 1).squeeze(0)
14
+
15
  class Attn(nn.Module):
16
  def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
17
  super().__init__()
 
98
  text_embeds = species[()][name]
99
 
100
  text_embeds = torch.tensor(text_embeds)
101
+ preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy() * land_mask
102
  pred_global = preds
103
  cmap = plt.get_cmap('Greens')
104
 
data/LAND_MASK.tif ADDED

Git LFS Details

  • SHA256: e260bd7833856c3a53b912fa8ebb35fb9850942006dd83579797c7cda78b8414
  • Pointer size: 132 Bytes
  • Size of remote file: 8.03 MB
requirements.txt CHANGED
@@ -2,4 +2,5 @@ numpy==1.23.4
2
  torch==2.0.1
3
  rasterio==1.3.8
4
  einops==0.6.1
5
- matplotlib
 
 
2
  torch==2.0.1
3
  rasterio==1.3.8
4
  einops==0.6.1
5
+ matplotlib
6
+ rasterio