import os
import torch
import yaml
import numpy as np
import gradio as gr
from pathlib import Path
from einops import rearrange
from functools import partial
from huggingface_hub import hf_hub_download
from terratorch.cli_tools import LightningInferenceModel

# pull files from hub
token = os.environ.get("HF_TOKEN", None)
config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
                              filename="burn_scars_config.yaml", token=token)
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
                             filename='Prithvi_EO_V2_300M_BurnScars.pt', token=token)
model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
                                  filename='inference.py', token=token)
os.system(f'cp {model_inference} .')

from inference import process_channel_group, _convert_np_uint8, load_example, run_model


def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
    """Wrapper function to extract RGB images (original, reconstructed, masked) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.
    """
    rgb_orig_list = []
    rgb_mask_list = []
    rgb_pred_list = []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels,
                                                   mean=mean,
                                                   std=std)
        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # extract images
        rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))

    # Add white dummy image values for missing timestamps
    dummy = np.ones((20, 20), dtype=np.uint8) * 255
    num_dummies = 4 - len(rgb_orig_list)
    if num_dummies:
        rgb_orig_list.extend([dummy] * num_dummies)
        rgb_mask_list.extend([dummy] * num_dummies)
        rgb_pred_list.extend([dummy] * num_dummies)

    outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

    return outputs


def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
    try:
        data_file = data_file.name
        print('Path extracted from example')
    except AttributeError:
        print('Files submitted through UI')

    # Get parameters --------
    print('This is the printout', data_file)

    with open(config_path, "r") as f:
        config_dict = yaml.safe_load(f)

    # Load model ---------------------------------------------------------------------------------
    print('Loading model')
    lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
    img_size = 512  # Size of BurnScars images
    print('Model loaded')

    # Load data -----------------------------------------------------------------------------------
    input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=[data_file])

    if input_data.shape[1] != 6:
        raise Exception(f'Input data has {input_data.shape[1]} channels. Expected six Prithvi channels.')

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    # Run model ------------------------------------------------------------------------------------
    lightning_model.model.eval()

    channels = [config_dict['data']['init_args']['output_bands'].index(b)
                for b in ["RED", "GREEN", "BLUE"]]  # BGR -> RGB

    pred = run_model(input_data, lightning_model.model, lightning_model.datamodule, img_size)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    # Extract RGB images for display
    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )

    out_rgb_orig = _convert_np_uint8(rgb_orig).transpose(1, 2, 0)
    out_pred_rgb = _convert_np_uint8(pred).repeat(3, axis=0).transpose(1, 2, 0)

    pred[pred == 0.] = np.nan
    img_pred = rgb_orig * 0.6 + pred * 0.4
    img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
    out_img_pred = _convert_np_uint8(img_pred).transpose(1, 2, 0)

    outputs = [out_rgb_orig] + [out_pred_rgb] + [out_img_pred]

    print("Done!")

    return outputs


run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)

with gr.Blocks() as demo:
    gr.Markdown(value='# Prithvi-EO-2.0 BurnScars Demo')
    gr.Markdown(value='''
Prithvi-EO-2.0 is the second-generation EO foundation model developed by the IBM and NASA team.
This demo showcases the fine-tuned Prithvi-EO-2.0-300M model to detect burn scars using HLS imagery from the [HLS Burn Scars dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars). More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars).\n
The user needs to provide an HLS image with the six Prithvi bands (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2). We recommend submitting images of roughly 500 to 1000 pixels per side for faster processing. Images bigger than 512x512 are processed using a sliding window approach, which can lead to artefacts between patches.\n
Some example images are provided at the end of this page.
''')
    with gr.Row():
        with gr.Column():
            inp_file = gr.File(elem_id='file')
            # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
            btn = gr.Button("Submit")
    with gr.Row():
        gr.Markdown(value='## Input image')
        gr.Markdown(value='## Prediction*')
        gr.Markdown(value='## Overlay')
    with gr.Row():
        original = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
        predicted = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
        overlay = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)

    gr.Markdown(value='\* White = burned; Black = not burned')

    btn.click(fn=run_inference,
              inputs=inp_file,
              outputs=[original, predicted, overlay])

    with gr.Row():
        gr.Examples(examples=[
            os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif"),
            os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SFH.2018185.v1.4_merged.tif"),
            os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SGF.2020217.v1.4_merged.tif")],
            inputs=inp_file,
            outputs=[original, predicted, overlay],
            fn=run_inference,
            cache_examples=True)

demo.launch(ssr_mode=False)
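# Usage sketch: `run_inference` can also be called outside the Gradio UI. The filename below is a
# hypothetical local GeoTIFF with the six Prithvi bands (it is not shipped with this Space):
#
#   rgb_orig, pred_rgb, overlay_rgb = run_inference("my_hls_chip.tif")
#
# It returns the same three uint8 arrays that are wired to the `original`, `predicted`, and
# `overlay` Image components above.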