data
Browse files- .gitattributes +1 -0
- app.py +5 -1
- data/LAND_MASK.tif +3 -0
- requirements.txt +2 -1
.gitattributes
CHANGED
@@ -39,3 +39,4 @@ data/species_70b.npy filter=lfs diff=lfs merge=lfs -text
|
|
39 |
data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
|
40 |
model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
|
41 |
data/family_70b.npy filter=lfs diff=lfs merge=lfs -text
|
|
|
|
39 |
data/pos_embeds_model.npy filter=lfs diff=lfs merge=lfs -text
|
40 |
model/demo_model.pt filter=lfs diff=lfs merge=lfs -text
|
41 |
data/family_70b.npy filter=lfs diff=lfs merge=lfs -text
|
42 |
+
data/LAND_MASK.tif filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -5,9 +5,13 @@ import torch.nn as nn
|
|
5 |
from einops import rearrange
|
6 |
import matplotlib.pyplot as plt
|
7 |
from copy import deepcopy
|
|
|
|
|
8 |
|
9 |
pred_global = None
|
10 |
|
|
|
|
|
11 |
class Attn(nn.Module):
|
12 |
def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
|
13 |
super().__init__()
|
@@ -94,7 +98,7 @@ def text_fn(taxon, name):
|
|
94 |
text_embeds = species[()][name]
|
95 |
|
96 |
text_embeds = torch.tensor(text_embeds)
|
97 |
-
preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
|
98 |
pred_global = preds
|
99 |
cmap = plt.get_cmap('Greens')
|
100 |
|
|
|
5 |
from einops import rearrange
|
6 |
import matplotlib.pyplot as plt
|
7 |
from copy import deepcopy
|
8 |
+
import rasterio
|
9 |
+
from rasterio.enums import Resampling
|
10 |
|
11 |
pred_global = None
|
12 |
|
13 |
+
land_mask = (rasterio.open('data/LAND_MASK.tif').read(out_shape=(1, 900, 1800), resampling=Resampling.nearest) == 1).squeeze(0)
|
14 |
+
|
15 |
class Attn(nn.Module):
|
16 |
def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
|
17 |
super().__init__()
|
|
|
98 |
text_embeds = species[()][name]
|
99 |
|
100 |
text_embeds = torch.tensor(text_embeds)
|
101 |
+
preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy() * land_mask
|
102 |
pred_global = preds
|
103 |
cmap = plt.get_cmap('Greens')
|
104 |
|
data/LAND_MASK.tif
ADDED
|
Git LFS Details
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ numpy==1.23.4
|
|
2 |
torch==2.0.1
|
3 |
rasterio==1.3.8
|
4 |
einops==0.6.1
|
5 |
-
matplotlib
|
|
|
|
2 |
torch==2.0.1
|
3 |
rasterio==1.3.8
|
4 |
einops==0.6.1
|
5 |
+
matplotlib
|
6 |
+
rasterio
|