jonathan-glider committed
Commit ced6b43 • Parent(s): c225b3b
Upload 14 files
Files changed:
- .gitattributes +6 -0
- Korea_data.zip +3 -0
- Prithvi.py +319 -0
- Prithvi_100M.pt +3 -0
- Prithvi_100M_config.yaml +19 -0
- Prithvi_run_inference.py +399 -0
- README.md +41 -0
- app.py +483 -0
- first.tif +3 -0
- requirements.txt +5 -0
- second.tif +3 -0
- streamlit-testing.webm +3 -0
- temp/1ab7b057-9543-4240-92a1-f85bba853af6.jpg +3 -0
- temp/771eca76-489a-41c7-bcfb-28b841e78dd7.jpg +3 -0
- third.tif +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+first.tif filter=lfs diff=lfs merge=lfs -text
+second.tif filter=lfs diff=lfs merge=lfs -text
+streamlit-testing.webm filter=lfs diff=lfs merge=lfs -text
+temp/1ab7b057-9543-4240-92a1-f85bba853af6.jpg filter=lfs diff=lfs merge=lfs -text
+temp/771eca76-489a-41c7-bcfb-28b841e78dd7.jpg filter=lfs diff=lfs merge=lfs -text
+third.tif filter=lfs diff=lfs merge=lfs -text
Korea_data.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9384fd7b5ee9e5aaf80c7251469e0fd68925b5cb8e1a5fabea8fd2cd8bb7c9bd
size 483473233
Prithvi.py
ADDED
@@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import Block
from timm.models.layers import to_2tuple

import numpy as np

from einops import rearrange

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: 3d tuple of grid size: t, h, w
    return:
    pos_embed: L, D
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


class PatchEmbed(nn.Module):
    """ Frames of 2D Images to Patch Embedding
    The 3D version of timm.models.vision_transformer.PatchEmbed
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            num_frames=3,
            tubelet_size=1,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
                              stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16,
                 num_frames=3, tubelet_size=1,
                 in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: B, C, T, H, W
        x: B, L, D
        """
        p = self.patch_embed.patch_size[0]
        tub = self.patch_embed.tubelet_size
        x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)

        return x

    def unpatchify(self, x):
        """
        x: B, L, D
        imgs: B, C, T, H, W
        """
        p = self.patch_embed.patch_size[0]
        num_p = self.patch_embed.img_size[0] // p
        tub = self.patch_embed.tubelet_size
        imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: B, C, T, H, W
        target: B, L, D
        pred: B, L, D
        mask: B, L. 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
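The positional encoding in this file is fixed (non-learned) and split across time, height, and width. A minimal sketch, assuming the grid used elsewhere in this commit (3 frames of 224×224 imagery with 16×16 patches, i.e. a 3×14×14 grid, at embed_dim 768), of what `get_3d_sincos_pos_embed` produces:

```
from Prithvi import get_3d_sincos_pos_embed

# embed_dim 768 is split 6/16 + 6/16 + 4/16 across width, height and time
pos_embed = get_3d_sincos_pos_embed(embed_dim=768, grid_size=(3, 14, 14), cls_token=True)
print(pos_embed.shape)  # (589, 768): 3*14*14 patch positions plus one cls-token row
```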
Prithvi_100M.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69f8ac286f649d1bbed520f5c8560a60eba91d688f74e1a0f9aa8203b6fd62ab
size 453672901
Prithvi_100M_config.yaml
ADDED
@@ -0,0 +1,19 @@
num_frames: 3
img_size: 224
bands: [B02, B03, B04, B05, B06, B07]
random_cropping: true
data_loader_num_workers: 1
depth: 12
decoder_depth: 8
patch_size: 16
embed_dim: 768
decoder_embed_dim: 512
num_heads: 12
decoder_num_heads: 16
mask_ratio: 0.75
tubelet_size: 1
data_mean: [775.2290211032589, 1080.992780391705, 1228.5855250417867, 2497.2022620507532,
  2204.2139147975554, 1610.8324823273745]
data_std: [1281.526139861424, 1270.0297974547493, 1399.4802505642526, 1368.3446143747644,
  1291.6764008585435, 1154.505683480695]
batch_size: 16
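These hyperparameters map one-to-one onto the `MaskedAutoencoderViT` constructor in `Prithvi.py`. A minimal sketch of that wiring (the same mapping `Prithvi_run_inference.py` below performs), assuming the yaml and `Prithvi.py` from this commit are importable; the input here is random data just to show the expected tensor shapes:

```
import functools

import torch
import yaml

from Prithvi import MaskedAutoencoderViT

with open("Prithvi_100M_config.yaml") as f:
    params = yaml.safe_load(f)

model = MaskedAutoencoderViT(
    img_size=params["img_size"],            # 224
    patch_size=params["patch_size"],        # 16
    num_frames=params["num_frames"],        # 3
    tubelet_size=params["tubelet_size"],    # 1
    in_chans=len(params["bands"]),          # 6 HLS bands (B02-B07)
    embed_dim=params["embed_dim"],          # 768
    depth=params["depth"],                  # 12
    num_heads=params["num_heads"],          # 12
    decoder_embed_dim=params["decoder_embed_dim"],
    decoder_depth=params["decoder_depth"],
    decoder_num_heads=params["decoder_num_heads"],
    mlp_ratio=4.0,
    norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
    norm_pix_loss=False,
)

# dummy input: (batch, bands, frames, height, width)
x = torch.randn(1, 6, 3, 224, 224)
loss, pred, mask = model(x, mask_ratio=params["mask_ratio"])
print(pred.shape, mask.shape)  # (1, 588, 1536) and (1, 588) for the 3x14x14 patch grid
```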
Prithvi_run_inference.py
ADDED
@@ -0,0 +1,399 @@
import argparse
import functools
import os
from typing import List

import numpy as np
import rasterio
import torch
import yaml
from einops import rearrange

from Prithvi import MaskedAutoencoderViT

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
PERCENTILES = (0.1, 99.9)


def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
    """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
    original range using *data_mean* and *data_std* and then lowest and highest percentiles are
    removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.

    Args:
        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
        new_img: torch.Tensor representing image with shape = (bands, H, W).
        channels: list of indices representing RGB channels.
        data_mean: list of mean values for each band.
        data_std: list of std values for each band.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for original image
        torch.Tensor with shape (num_channels, height, width) for the other image
    """

    stack_c = [], []

    for c in channels:
        orig_ch = orig_img[c, ...]
        valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
        valid_mask[orig_ch == NO_DATA_FLOAT] = False

        # Back to original data range
        orig_ch = (orig_ch * data_std[c]) + data_mean[c]
        new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]

        # Rescale (enhancing contrast)
        min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)

        orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
        new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)

        # No data as zeros
        orig_ch[~valid_mask] = 0
        new_ch[~valid_mask] = 0

        stack_c[0].append(orig_ch)
        stack_c[1].append(new_ch)

    # Channels first
    stack_orig = torch.stack(stack_c[0], dim=0)
    stack_rec = torch.stack(stack_c[1], dim=0)

    return stack_orig, stack_rec


def read_geotiff(file_path: str):
    """ Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta

    return img, meta


def save_geotiff(image, output_path: str, meta: dict):
    """ Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):

    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(file_paths: List[str], mean: List[float], std: List[float]):
    """ Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images in *file_paths*.
        std: list containing std values for each band in the images in *file_paths*.

    Returns:
        np.array containing created example
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []

    for file in file_paths:
        img, meta = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype('float32')  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, metas


def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
    """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).

    Args:
        model: MAE model to run.
        input_data: torch.Tensor with shape (B, C, T, H, W).
        mask_ratio: mask ratio to use.
        device: device where model should run.

    Returns:
        3 torch.Tensor with shape (B, C, T, H, W).
    """

    with torch.no_grad():
        x = input_data.to(device)

        _, pred, mask = model(x, mask_ratio)

    # Create mask and prediction images (un-patchify)
    mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
    pred_img = model.unpatchify(pred).detach().cpu()

    # Mix visible and predicted patches
    rec_img = input_data.clone()
    rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove

    # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
    mask_img = (~(mask_img.to(torch.bool))).to(torch.float)

    return rec_img, mask_img


def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
    """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels, data_mean=mean,
                                                   data_std=std)

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # Saving images

        save_geotiff(image=_convert_np_uint8(rgb_orig),
                     output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
                     meta=meta_data[t])

        save_geotiff(image=_convert_np_uint8(rgb_pred),
                     output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
                     meta=meta_data[t])

        save_geotiff(image=_convert_np_uint8(rgb_mask),
                     output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
                     meta=meta_data[t])


def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
    """ Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.

    Args:
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    mean = torch.tensor(np.asarray(mean)[:, None, None])  # C H W
    std = torch.tensor(np.asarray(std)[:, None, None])

    for t in range(rec_img.shape[1]):

        # Back to original data range
        rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)

        mask_img_t = mask_img[:, t, :, :].to(torch.int16)

        # Saving images

        save_geotiff(image=rec_img_t,
                     output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
                     meta=meta_data[t])

        save_geotiff(image=mask_img_t,
                     output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
                     meta=meta_data[t])


def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str,
         mask_ratio: float, rgb_outputs: bool):

    os.makedirs(output_dir, exist_ok=True)

    # Get parameters --------

    with open(yaml_file_path, 'r') as f:
        params = yaml.safe_load(f)

    # data related
    num_frames = len(data_files)
    img_size = params['img_size']
    bands = params['bands']
    mean = params['data_mean']
    std = params['data_std']

    # model related
    depth = params['depth']
    patch_size = params['patch_size']
    embed_dim = params['embed_dim']
    num_heads = params['num_heads']
    tubelet_size = params['tubelet_size']
    decoder_embed_dim = params['decoder_embed_dim']
    decoder_num_heads = params['decoder_num_heads']
    decoder_depth = params['decoder_depth']

    batch_size = params['batch_size']

    mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio

    print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
    if len(data_files) != 3:
        print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f"Using {device} device.\n")

    # Loading data ---------------------------------------------------------------------------------

    input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)

    # Create model and load checkpoint -------------------------------------------------------------

    model = MaskedAutoencoderViT(
            img_size=img_size,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            in_chans=len(bands),
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=4.,
            norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
            norm_pix_loss=False)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device)
    # discard fixed pos_embedding weight
    del state_dict['pos_embed']
    del state_dict['decoder_pos_embed']
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint from {checkpoint}")

    # Running model --------------------------------------------------------------------------------

    model.eval()
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # BGR -> RGB

    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = img_size - (original_h % img_size)
    pad_w = img_size - (original_w % img_size)
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

    # Build sliding window
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    # Run model
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        rec_img, mask_img = run_model(model, x, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Build images from patches
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)

    # Cut padded images back to original size
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    # Build output images
    if rgb_outputs:
        for d in meta_data:
            d.update(count=3, dtype='uint8', compress='lzw', nodata=0)

        save_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                      channels, mean, std, output_dir, meta_data)
    else:
        for d in meta_data:
            d.update(compress='lzw', nodata=0)

        save_imgs(rec_imgs_full[0, ...], mask_imgs_full[0, ...], mean, std, output_dir, meta_data)

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser('MAE run inference', add_help=False)

    parser.add_argument('--data_files', required=True, type=str, nargs='+',
                        help='Path to the data files. Assumes multi-band files.')
    parser.add_argument('--yaml_file_path', type=str, required=True,
                        help='Path to yaml file containing model training parameters.')
    parser.add_argument('--checkpoint', required=True, type=str,
                        help='Path to a checkpoint file to load from.')
    parser.add_argument('--output_dir', required=True, type=str,
                        help='Path to the directory where to save outputs.')
    parser.add_argument('--mask_ratio', default=None, type=float,
                        help='Masking ratio (percentage of removed patches). '
                             'If None (default) use same value used for pretraining.')
    parser.add_argument('--rgb_outputs', action='store_true',
                        help='If present, output files will only contain RGB channels. '
                             'Otherwise, all bands will be saved.')
    args = parser.parse_args()

    main(**vars(args))
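The script above can be driven either through its CLI flags (`--data_files`, `--yaml_file_path`, `--checkpoint`, `--output_dir`, `--mask_ratio`, `--rgb_outputs`) or programmatically. A minimal sketch of the latter, assuming the three GeoTIFFs, the config, and the checkpoint shipped in this commit sit in the working directory; the output directory name is arbitrary:

```
# Sketch: reconstruct the three Jeollanam-do frames from this commit and write
# original/predicted/masked RGB GeoTIFFs into ./output (equivalent to the CLI).
from Prithvi_run_inference import main

main(data_files=["first.tif", "second.tif", "third.tif"],
     yaml_file_path="Prithvi_100M_config.yaml",
     checkpoint="Prithvi_100M.pt",
     output_dir="output",
     mask_ratio=None,      # None -> fall back to mask_ratio: 0.75 from the yaml
     rgb_outputs=True)     # save RGB composites only
```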
README.md
ADDED
@@ -0,0 +1,41 @@
---
license: apache-2.0
tags:
- Pytorch
- Geospatial
- Temporal ViT
- Vit
---


### Code
The model follows the [original repo](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M).
I made a small modification to the original repo's source code in order to visualize the results.

### Data
Area: Jeollanam-do, South Korea
The imagery was sourced from this [link](https://search.earthdata.nasa.gov/search/granules?p=C2021957657-LPCLOUD&pg[0][v]=f&pg[0][qt]=2009-01-01T00%3A00%3A00.000Z%2C&pg[0][dnf]=DAY&pg[0][gsk]=-start_date&q=HLSL30&sb[0]=126.57129%2C34.87923%2C126.97998%2C35.09012&tl=1696429462!3!!&lat=33.70166015625&long=125.0771484375&zoom=7)
Google Maps location: [link](https://www.google.co.kr/maps/place/34%C2%B052'45.2%22N+126%C2%B034'16.6%22E/data=!4m4!3m3!8m2!3d34.87923!4d126.57129?hl=ko&entry=ttu)

### Use case
Here's a sample video:

![Sample Video](streamlit-testing.webm)



### Citation

Please cite the original repository.
If this model helped your research, please cite `Prithvi-100M` in your publications. Here is an example BibTeX entry:

```
@misc{Prithvi-100M,
    author = {Jakubik, Johannes and Chu, Linsong and Fraccaro, Paolo and Gomes, Carlos and Nyirjesy, Gabby and Bangalore, Ranjini and Lambhate, Devyani and Das, Kamal and Oliveira Borges, Dario and Kimura, Daiki and Simumba, Naomi and Szwarcman, Daniela and Muszynski, Michal and Weldemariam, Kommy and Zadrozny, Bianca and Ganti, Raghu and Costa, Carlos and Edwards, Blair & Watson, Campbell and Mukkavilli, Karthik and Schmude, Johannes & Hamann, Hendrik and Robert, Parkin and Roy, Sujit and Phillips, Christopher and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Leong, Wei Ji and Avery, Ryan and Ramachandran, Rahul and Maskey, Manil and Olofossen, Pontus and Fancher, Elizabeth and Lee, Tsengdar and Murphy, Kevin and Duffy, Dan and Little, Mike and Alemohammad, Hamed and Cecil, Michael and Li, Steve and Khallaghi, Sam and Godwin, Denys and Ahmadi, Maryam and Kordi, Fatemeh and Saux, Bertrand and Pastick, Neal and Doucette, Peter and Fleckenstein, Rylie and Luanga, Dalton and Corvin, Alex and Granger, Erwan},
    doi = {10.57967/hf/0952},
    month = aug,
    title = {{Prithvi-100M}},
    repository-code = {https://github.com/NASA-IMPACT/hls-foundation-os},
    year = {2023}
}
```
app.py
ADDED
@@ -0,0 +1,483 @@
from huggingface_hub import hf_hub_download
import os
import functools
from typing import List
import numpy as np
import torch
import yaml
from einops import rearrange
from Prithvi import MaskedAutoencoderViT
from functools import partial

import rasterio
from rasterio.merge import merge
from rasterio import Affine
from rasterio.warp import calculate_default_transform, reproject, Resampling

import streamlit as st
from streamlit_image_comparison import image_comparison

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
PERCENTILES = (0.1, 99.9)

TOKEN = "JONATHAN_TOKEN"
yaml_file_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename="Prithvi_100M_config.yaml", token=TOKEN)
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi_100M.pt', token=TOKEN)
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi.py', token=TOKEN)



def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
    """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
    original range using *data_mean* and *data_std* and then lowest and highest percentiles are
    removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
    Args:
        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
        new_img: torch.Tensor representing image with shape = (bands, H, W).
        channels: list of indices representing RGB channels.
        data_mean: list of mean values for each band.
        data_std: list of std values for each band.
    Returns:
        torch.Tensor with shape (num_channels, height, width) for original image
        torch.Tensor with shape (num_channels, height, width) for the other image
    """

    stack_c = [], []

    for c in channels:
        orig_ch = orig_img[c, ...]
        valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
        valid_mask[orig_ch == NO_DATA_FLOAT] = False

        # Back to original data range
        orig_ch = (orig_ch * data_std[c]) + data_mean[c]
        new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]

        # Rescale (enhancing contrast)
        min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)

        orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
        new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)

        # No data as zeros
        orig_ch[~valid_mask] = 0
        new_ch[~valid_mask] = 0

        stack_c[0].append(orig_ch)
        stack_c[1].append(new_ch)

    # Channels first
    stack_orig = torch.stack(stack_c[0], dim=0)
    stack_rec = torch.stack(stack_c[1], dim=0)

    return stack_orig, stack_rec


def read_geotiff(file_path: str):
    """ Read all bands from *file_path* and return image + meta info.
    Args:
        file_path: path to image file.
    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta

    return img, meta


def save_geotiff(image, output_path: str, meta: dict):
    """ Save multi-band image in Geotiff file.
    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):

    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)
    image = image.transpose((1, 2, 0))

    return image


def load_example(file_paths: List[str], mean: List[float], std: List[float]):
    """ Build an input example by loading images in *file_paths*.
    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images in *file_paths*.
        std: list containing std values for each band in the images in *file_paths*.
    Returns:
        np.array containing created example
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []

    for file in file_paths:
        img, meta = read_geotiff(file)
        img = img[:6]*10000 if img[:6].mean() <= 2 else img[:6]

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype('float32')  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, metas


def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
    """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
    Args:
        model: MAE model to run.
        input_data: torch.Tensor with shape (B, C, T, H, W).
        mask_ratio: mask ratio to use.
        device: device where model should run.
    Returns:
        3 torch.Tensor with shape (B, C, T, H, W).
    """

    with torch.no_grad():
        x = input_data.to(device)

        _, pred, mask = model(x, mask_ratio)

    # Create mask and prediction images (un-patchify)
    mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
    pred_img = model.unpatchify(pred).detach().cpu()

    # Mix visible and predicted patches
    rec_img = input_data.clone()
    rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove

    # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
    mask_img = (~(mask_img.to(torch.bool))).to(torch.float)

    return rec_img, mask_img


def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
    """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
        output_dir: directory where to save outputs.
        meta_data: list of dicts with geotiff meta info.
    """

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels, data_mean=mean,
                                                   data_std=std)

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # Saving images

        save_geotiff(image=_convert_np_uint8(rgb_orig),
                     output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
                     meta=meta_data[t])

        save_geotiff(image=_convert_np_uint8(rgb_pred),
                     output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
                     meta=meta_data[t])

        save_geotiff(image=_convert_np_uint8(rgb_mask),
                     output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
                     meta=meta_data[t])


def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
    """ Wrapper function to extract RGB images (original, masked, reconstructed) per timestamp.
    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
    """
    rgb_orig_list = []
    rgb_mask_list = []
    rgb_pred_list = []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels, data_mean=mean,
                                                   data_std=std)

        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # extract images
        rgb_orig_list.append(_convert_np_uint8(rgb_orig))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred))

    outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

    return outputs


def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str, checkpoint: str):

    try:
        data_files = [x.name for x in data_files]
        print('Path extracted from example')
    except:
        print('Files submitted through UI')

    # Get parameters --------
    print('This is the printout', data_files)

    with open(yaml_file_path, 'r') as f:
        params = yaml.safe_load(f)

    # data related
    num_frames = params['num_frames']
    img_size = params['img_size']
    bands = params['bands']
    mean = params['data_mean']
    std = params['data_std']

    # model related
    depth = params['depth']
    patch_size = params['patch_size']
    embed_dim = params['embed_dim']
    num_heads = params['num_heads']
    tubelet_size = params['tubelet_size']
    decoder_embed_dim = params['decoder_embed_dim']
    decoder_num_heads = params['decoder_num_heads']
    decoder_depth = params['decoder_depth']

    batch_size = params['batch_size']

    mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio

    # We must have *num_frames* files to build one example!
    assert len(data_files) == num_frames, "File list must be equal to expected number of frames."

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f"Using {device} device.\n")

    # Loading data ---------------------------------------------------------------------------------

    input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)

    # Create model and load checkpoint -------------------------------------------------------------

    model = MaskedAutoencoderViT(
            img_size=img_size,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            in_chans=len(bands),
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=4.,
            norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
            norm_pix_loss=False)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Loaded checkpoint from {checkpoint}")

    # Running model --------------------------------------------------------------------------------

    model.eval()
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # BGR -> RGB

    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = img_size - (original_h % img_size)
    pad_w = img_size - (original_w % img_size)
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

    # Build sliding window
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    # Run model
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        rec_img, mask_img = run_model(model, x, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Build images from patches
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)

    # Cut padded images back to original size
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    # Build RGB images
    for d in meta_data:
        d.update(count=3, dtype='uint8', compress='lzw', nodata=0)

    # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
    #               channels, mean, std, output_dir, meta_data)

    outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                               channels, mean, std)

    print("Done!")

    return outputs

# partial function prep
func = partial(predict_on_images, yaml_file_path=yaml_file_path, checkpoint=checkpoint, mask_ratio=0.75)


## South Korea rural area HLS Landsat merger (bands B02 through B07)

def raw_NASA_tif_file_merger(PSD_NAME, tif_files):
    src_files_to_mosaic = []
    for tif_file in tif_files:
        src = rasterio.open(tif_file)
        src_files_to_mosaic.append(src)
    mosaic, out_trans = merge(src_files_to_mosaic)
    out_meta = src.meta.copy()
    out_meta.update({"driver": "GTiff",
                     "height": mosaic.shape[1],
                     "width": mosaic.shape[2],
                     "transform": out_trans})

    with rasterio.open(PSD_NAME, "w", **out_meta) as dest:
        dest.write(mosaic)

# raw_NASA_tif_file_merger("third.tif","./1/*.tif")


# streamlit area
def main_loop():
    st.title("HuggingFace Inference Demo")
    st.subheader("Be sure to set the parameter")

    [out1_orig_t1, out2_orig_t2, out3_orig_t3,
     out4_masked_t1, out5_masked_t2, out6_masked_t3,
     out7_pred_t1, out8_pred_t2, out9_pred_t3] = func(["first.tif", "second.tif", "third.tif"])

    st.markdown("### first original image and masked image comparison")
    image_comparison(
        img1=out1_orig_t1,
        img2=out4_masked_t1,
        label1="original-1",
        label2="masked-1",
        width=1024,
    )

    st.markdown("### second original image and masked image comparison")
    image_comparison(
        img1=out2_orig_t2,
        img2=out5_masked_t2,
        label1="original-2",
        label2="masked-2",
        width=1024,
    )

    st.markdown("### third original image and masked image comparison")
    image_comparison(
        img1=out3_orig_t3,
        img2=out6_masked_t3,
        label1="original-3",
        label2="masked-3",
        width=1024,
    )

    st.markdown("### first original image and encoded image comparison")
    image_comparison(
        img1=out1_orig_t1,
        img2=out7_pred_t1,
        label1="original-1",
        label2="predicted-1",
        width=1024,
    )

    st.markdown("### second original image and encoded image comparison")
    image_comparison(
        img1=out2_orig_t2,
        img2=out8_pred_t2,
        label1="original-2",
        label2="predicted-2",
        width=1024,
    )

    st.markdown("### third original image and encoded image comparison")
    image_comparison(
        img1=out3_orig_t3,
        img2=out9_pred_t3,
        label1="original-3",
        label2="predicted-3",
        width=1024,
    )

if __name__ == '__main__':
    main_loop()
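The front end above is launched the usual way (`streamlit run app.py`). A minimal sketch of driving the same pipeline headlessly through the `func` partial that `app.py` defines, assuming the hub token placeholder has been replaced with valid credentials and that Pillow is installed (it is not listed in `requirements.txt`):

```
# Sketch: reuse app.py's bound predict_on_images to get the nine RGB arrays
# (3 originals, 3 masked, 3 reconstructions) and dump them as PNGs.
from PIL import Image

from app import func  # partial(predict_on_images, yaml_file_path=..., checkpoint=..., mask_ratio=0.75)

outputs = func(["first.tif", "second.tif", "third.tif"])
names = [f"{kind}_t{t}" for kind in ("original", "masked", "predicted") for t in range(3)]

for name, arr in zip(names, outputs):      # each arr is (H, W, 3) uint8
    Image.fromarray(arr).save(f"{name}.png")
```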
first.tif
ADDED
Git LFS Details
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
timm
einops
rasterio
second.tif
ADDED
Git LFS Details
streamlit-testing.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:590a810b981ac32d78a04f65ac78b4f2c15d77e465b95b659c1e1be83181e94a
size 10384274
temp/1ab7b057-9543-4240-92a1-f85bba853af6.jpg
ADDED
Git LFS Details
temp/771eca76-489a-41c7-bcfb-28b841e78dd7.jpg
ADDED
Git LFS Details
third.tif
ADDED
Git LFS Details