Paolo-Fraccaro
commited on
Commit
·
f62a54e
1
Parent(s):
ff3ecff
add padding
Browse files
app.py
CHANGED
@@ -47,7 +47,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
|
|
47 |
for c in channels:
|
48 |
orig_ch = orig_img[c, ...]
|
49 |
valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
|
50 |
-
valid_mask[orig_ch ==
|
51 |
|
52 |
# Back to original data range
|
53 |
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
|
@@ -138,8 +138,8 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
|
|
138 |
imgs.append(img)
|
139 |
metas.append(meta)
|
140 |
|
141 |
-
imgs = np.stack(imgs, axis=0) # num_frames,
|
142 |
-
imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames,
|
143 |
imgs = np.expand_dims(imgs, axis=0) # add batch dim
|
144 |
|
145 |
return imgs, metas
|
@@ -308,7 +308,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
|
|
308 |
norm_pix_loss=False)
|
309 |
|
310 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
311 |
-
print(f"\n-->
|
312 |
|
313 |
model.to(device)
|
314 |
|
@@ -320,6 +320,12 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
|
|
320 |
|
321 |
model.eval()
|
322 |
channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
|
324 |
# Build sliding window
|
325 |
batch = torch.tensor(input_data, device='cpu')
|
@@ -348,13 +354,10 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
|
|
348 |
mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
|
349 |
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
|
350 |
|
351 |
-
#
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
mask_imgs_full = torch.ones_like(batch)
|
357 |
-
mask_imgs_full[..., :h, :w] = mask_imgs
|
358 |
|
359 |
# Build RGB images
|
360 |
for d in meta_data:
|
@@ -363,7 +366,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
|
|
363 |
# save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
|
364 |
# channels, mean, std, output_dir, meta_data)
|
365 |
|
366 |
-
outputs = extract_rgb_imgs(
|
367 |
channels, mean, std)
|
368 |
|
369 |
|
|
|
47 |
for c in channels:
|
48 |
orig_ch = orig_img[c, ...]
|
49 |
valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
|
50 |
+
valid_mask[orig_ch == NO_DATA_FLOAT] = False
|
51 |
|
52 |
# Back to original data range
|
53 |
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
|
|
|
138 |
imgs.append(img)
|
139 |
metas.append(meta)
|
140 |
|
141 |
+
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
142 |
+
imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
|
143 |
imgs = np.expand_dims(imgs, axis=0) # add batch dim
|
144 |
|
145 |
return imgs, metas
|
|
|
308 |
norm_pix_loss=False)
|
309 |
|
310 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
311 |
+
print(f"\n--> Model has {total_params:,} parameters.\n")
|
312 |
|
313 |
model.to(device)
|
314 |
|
|
|
320 |
|
321 |
model.eval()
|
322 |
channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
|
323 |
+
|
324 |
+
# Reflect pad if not divisible by img_size
|
325 |
+
original_h, original_w = input_data.shape[-2:]
|
326 |
+
pad_h = img_size - (original_h % img_size)
|
327 |
+
pad_w = img_size - (original_w % img_size)
|
328 |
+
input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
|
329 |
|
330 |
# Build sliding window
|
331 |
batch = torch.tensor(input_data, device='cpu')
|
|
|
354 |
mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
|
355 |
h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
|
356 |
|
357 |
+
# Cut padded images back to original size
|
358 |
+
rec_imgs_full = rec_imgs[..., :original_h, :original_w]
|
359 |
+
mask_imgs_full = mask_imgs[..., :original_h, :original_w]
|
360 |
+
batch_full = batch[..., :original_h, :original_w]
|
|
|
|
|
|
|
361 |
|
362 |
# Build RGB images
|
363 |
for d in meta_data:
|
|
|
366 |
# save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
|
367 |
# channels, mean, std, output_dir, meta_data)
|
368 |
|
369 |
+
outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
|
370 |
channels, mean, std)
|
371 |
|
372 |
|