import os
import torch
import yaml
import numpy as np
import gradio as gr
from einops import rearrange
from functools import partial
from huggingface_hub import hf_hub_download


token = os.environ.get("HF_TOKEN", None)

config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                              filename="config.json", token=token)
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                             filename='Prithvi_EO_V1_100M.pt', token=token)
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                            filename='prithvi_mae.py', token=token)
model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
                                  filename='inference.py', token=token)
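
# Copy the downloaded model definition and inference helpers next to this script so they can be imported.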
os.system(f'cp {model_def} .')
os.system(f'cp {model_inference} .')

from prithvi_mae import PrithviMAE
from inference import process_channel_group, _convert_np_uint8, load_example, run_model


def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
    """ Extract RGB images (original, masked, reconstructed) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of band indices representing the RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.

    Returns:
        List of uint8 RGB arrays (H, W, 3): original, masked, and reconstructed images.
    """
    rgb_orig_list = []
    rgb_mask_list = []
    rgb_pred_list = []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels,
                                                   mean=mean,
                                                   std=std)
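
        # Build the masked-input visualisation by multiplying the original RGB image with the mask.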
        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
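
    # Pad with white placeholder images so the UI always receives three images per column.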
    dummy = np.ones((20, 20), dtype=np.uint8) * 255
    num_dummies = 3 - len(rgb_orig_list)
    if num_dummies:
        rgb_orig_list.extend([dummy] * num_dummies)
        rgb_mask_list.extend([dummy] * num_dummies)
        rgb_pred_list.extend([dummy] * num_dummies)

    outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

    return outputs


def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
    try:
        data_files = [x.name for x in data_files]
        print('Path extracted from example')
    except (AttributeError, TypeError):
        print('Files submitted through UI')

    print('Input files:', data_files)
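
    # config.json is plain JSON, which yaml.safe_load can also parse.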
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)['pretrained_cfg']

    batch_size = 8
    bands = config['bands']
    num_frames = len(data_files)
    mean = config['mean']
    std = config['std']
    img_size = config['img_size']
    mask_ratio = mask_ratio or config['mask_ratio']

    assert num_frames <= 3, "Demo only supports up to three timestamps"

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f"Using {device} device.\n")
    input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)

    config.update(
        num_frames=num_frames,
    )

    model = PrithviMAE(**config)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device, weights_only=False)
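
    # Drop positional embedding weights from the checkpoint; the model keeps its own, sized for the requested number of frames.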
    for k in list(state_dict.keys()):
        if 'pos_embed' in k:
            del state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint from {checkpoint}")

    model.eval()
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]
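
    # Reflect-pad the height and width up to the next multiple of img_size so the image tiles into full windows.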
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - original_h % img_size) % img_size
    pad_w = (img_size - original_w % img_size) % img_size
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
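
    # Tile the padded image into non-overlapping img_size x img_size windows and flatten them into the batch dimension.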
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)

    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)
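
    # Run the MAE on each batch of windows; run_model returns the reconstructed image and the mask for each window.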
    rec_imgs = []
    mask_imgs = []
    for x in windows:
        rec_img, mask_img = run_model(model, x, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)
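
    # Stitch the windows back together into full-size images.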
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
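
    # Crop the reflection padding off to restore the original spatial size.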
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]
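
    # Update the geotiff metadata for 3-band uint8 RGB outputs.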
    for d in meta_data:
        d.update(count=3, dtype='uint8', compress='lzw', nodata=0)

    outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                               channels, mean, std)

    print("Done!")

    return outputs


run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)

with gr.Blocks() as demo:

    gr.Markdown(value='# Prithvi-EO-1.0 image reconstruction demo')
    gr.Markdown(value='''
Check out our newest model: [Prithvi-EO-2.0-Demo](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).

Prithvi is a first-of-its-kind temporal Vision Transformer pretrained by the IBM and NASA team on continental US Harmonized Landsat Sentinel-2 (HLS) data.
The model adopts a self-supervised encoder with a ViT architecture and a Masked Autoencoder (MAE) learning strategy, using an MSE loss function.
It applies spatial attention across multiple patches as well as temporal attention for each patch.
More info about the model and its weights is available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M).\n

This demo showcases image reconstruction over one to three timestamps.
The model randomly masks out some proportion of the input images and reconstructs them based on the unmasked portion of the images.
The reconstructed images are merged with the visible unmasked patches.
We recommend submitting images of size 224 to ~1000 pixels for faster processing time.
Images bigger than 224x224 are processed using a sliding window approach, which can lead to artefacts between patches.\n

The user needs to provide HLS geotiff images including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR 1, and SWIR 2.
Some example images are provided at the end of this page.
''')
    with gr.Row():
        with gr.Column():
            inp_files = gr.Files(elem_id='files')

            btn = gr.Button("Submit")
    with gr.Row():
        gr.Markdown(value='## Input time series')
        gr.Markdown(value='## Masked images')
        gr.Markdown(value='## Reconstructed images*')

    original = []
    masked = []
    predicted = []
    timestamps = []
    for t in range(3):
        timestamps.append(gr.Column(visible=t == 0))
        with timestamps[t]:
            with gr.Row():
                original.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                masked.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                predicted.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))

    gr.Markdown(value='\* The reconstructed images include the ground truth unmasked patches.')

    btn.click(fn=run_inference,
              inputs=inp_files,
              outputs=original + masked + predicted)

    with gr.Row():
        gr.Examples(examples=[[[
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
        ]], [[
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
        ]], [[
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
        ]]],
                    inputs=inp_files,
                    outputs=original + masked + predicted,
                    fn=run_inference,
                    cache_examples=True
                    )
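
    # Show only as many timestamp columns as files were uploaded.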
    def update_visibility(files):
        files = files or []
        timestamps = [gr.Column(visible=t < len(files)) for t in range(3)]

        return timestamps

    inp_files.change(update_visibility, inp_files, timestamps)

demo.launch(share=True, ssr_mode=False)