blumenstiel committed on
Commit c565e6b · 1 Parent(s): c5f7af6

Add demo code

Files changed (4)
  1. .gitattributes +1 -0
  2. Dockerfile +38 -0
  3. app.py +173 -0
  4. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.tif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM ubuntu:22.04
+
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.9 \
+     python3-pip \
+     git \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "app.py"]
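
With these files in place, the container should also be runnable locally with something like `docker build -t burn-scars-demo .` followed by `docker run -p 7860:7860 burn-scars-demo` (the image tag is arbitrary, and 7860 is Gradio's default server port; on Hugging Face Spaces the build and port mapping are handled automatically).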
app.py ADDED
@@ -0,0 +1,173 @@
+
+ import os
+ import torch
+ import yaml
+ import numpy as np
+ import gradio as gr
+ from pathlib import Path
+ from einops import rearrange
+ from functools import partial
+ from huggingface_hub import hf_hub_download
+ from terratorch.cli_tools import LightningInferenceModel
+
+ # Pull config, checkpoint, and inference code from the Hub
+ token = os.environ.get("HF_TOKEN", None)
+ config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
+                               filename="burn_scars_config.yaml", token=token)
+ checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
+                              filename='Prithvi_EO_V2_300M_BurnScars.pt', token=token)
+ model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars",
+                                   filename='inference.py', token=token)
+ os.system(f'cp {model_inference} .')
+
+ from inference import process_channel_group, _convert_np_uint8, load_example, run_model
+
+
+ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
+     """Extract RGB images (original, masked, reconstructed) per timestamp.
+
+     Args:
+         input_img: input torch.Tensor with shape (C, T, H, W).
+         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
+         mask_img: mask torch.Tensor with shape (C, T, H, W).
+         channels: list of indices representing RGB channels.
+         mean: list of mean values for each band.
+         std: list of std values for each band.
+     """
+     rgb_orig_list = []
+     rgb_mask_list = []
+     rgb_pred_list = []
+
+     for t in range(input_img.shape[1]):
+         rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
+                                                    new_img=rec_img[:, t, :, :],
+                                                    channels=channels,
+                                                    mean=mean,
+                                                    std=std)
+
+         rgb_mask = mask_img[channels, t, :, :] * rgb_orig
+
+         # Extract images
+         rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
+         rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
+         rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
+
+     # Add white dummy images for missing timestamps
+     dummy = np.ones((20, 20), dtype=np.uint8) * 255
+     num_dummies = 4 - len(rgb_orig_list)
+     if num_dummies:
+         rgb_orig_list.extend([dummy] * num_dummies)
+         rgb_mask_list.extend([dummy] * num_dummies)
+         rgb_pred_list.extend([dummy] * num_dummies)
+
+     outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
+
+     return outputs
+
+
+ def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
+     try:
+         data_file = data_file.name
+         print('Path extracted from example')
+     except AttributeError:
+         print('File submitted through UI')
+
+     # Get parameters -------------------------------------------------------------------------------
+     print('Input file:', data_file)
+
+     with open(config_path, "r") as f:
+         config_dict = yaml.safe_load(f)
+
+     # Load model -----------------------------------------------------------------------------------
+     lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
+     img_size = 256  # Tile size for sliding-window inference
+
+     # Load data ------------------------------------------------------------------------------------
+     input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=[data_file])
+
+     if input_data.shape[1] != 6:
+         raise ValueError(f'Input data has {input_data.shape[1]} channels. Expected the six Prithvi channels.')
+
+     if input_data.mean() > 1:
+         input_data = input_data / 10000  # Convert to reflectance range 0-1
+
+     # Run model ------------------------------------------------------------------------------------
+     lightning_model.model.eval()
+
+     channels = [config_dict['data']['init_args']['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]]  # BGR -> RGB
+
+     pred = run_model(input_data, temporal_coords, location_coords,
+                      lightning_model.model, lightning_model.datamodule, img_size)
+
+     if input_data.mean() < 1:
+         input_data = input_data * 10000  # Scale back to 0-10000
+
+     # Extract RGB images for display
+     rgb_orig = process_channel_group(
+         orig_img=torch.Tensor(input_data[0, :, 0, ...]),
+         channels=channels,
+     )
+     out_rgb_orig = _convert_np_uint8(rgb_orig).transpose(1, 2, 0)
+     out_pred_rgb = _convert_np_uint8(pred).repeat(3, axis=0).transpose(1, 2, 0)
+
+     # Overlay the prediction on the input; unburned (NaN) pixels keep the original values
+     pred[pred == 0.] = np.nan
+     img_pred = rgb_orig * 0.6 + pred * 0.4
+     img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
+
+     out_img_pred = _convert_np_uint8(img_pred).transpose(1, 2, 0)
+
+     outputs = [out_rgb_orig, out_pred_rgb, out_img_pred]
+
+     print("Done!")
+
+     return outputs
+
+
+ run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)
+
+ with gr.Blocks() as demo:
+     gr.Markdown(value='# Prithvi-EO-2.0 BurnScars Demo')
+     gr.Markdown(value='''
+     Prithvi-EO-2.0 is the second-generation EO foundation model developed by the IBM and NASA team.
+     This demo showcases Prithvi-EO-2.0-300M fine-tuned to detect burn scars in HLS imagery on the [HLS Burn Scars dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars). More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars).\n
+
+     The user needs to provide an HLS image with the six Prithvi bands (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2).
+     We recommend submitting images of 500 to ~1000 pixels for faster processing. Images larger than 512x512 are processed with a sliding-window approach, which can lead to artefacts between patches.\n
+     Some example images are provided at the end of this page.
+     ''')
+     with gr.Row():
+         with gr.Column():
+             inp_file = gr.File(elem_id='file')
+             # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
+             btn = gr.Button("Submit")
+     with gr.Row():
+         gr.Markdown(value='## Input image')
+         gr.Markdown(value='## Prediction*')
+         gr.Markdown(value='## Overlay')
+
+     with gr.Row():
+         original = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
+         predicted = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
+         overlay = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
+
+     gr.Markdown(value='\* White = burned; Black = not burned')
+
+     btn.click(fn=run_inference,
+               inputs=inp_file,
+               outputs=[original, predicted, overlay])
+
+     with gr.Row():
+         gr.Examples(examples=[
+             os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif"),
+             os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SFF.2018190.v1.4_merged.tif"),
+             os.path.join(os.path.dirname(__file__), "examples/subsetted_512x512_HLS.S30.T10SGF.2020217.v1.4_merged.tif")],
+             inputs=inp_file,
+             outputs=[original, predicted, overlay],
+             fn=run_inference,
+             cache_examples=True)
+
+ demo.launch(share=True, ssr_mode=False)
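
Once the demo is up, it can also be driven programmatically. A minimal sketch with gradio_client (not part of this commit; the Space id and endpoint name below are assumptions, and `client.view_api()` lists the real endpoints):

    # Hypothetical client-side call to the running demo.
    from gradio_client import Client, handle_file  # handle_file requires gradio_client >= 1.0

    client = Client("ibm-nasa-geospatial/Prithvi-EO-2.0-BurnScars-demo")  # assumed Space id
    client.view_api()  # prints the actual endpoint names and signatures
    original, prediction, overlay = client.predict(
        handle_file("subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif"),
        api_name="/predict_on_images",  # assumed; check the view_api() output
    )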
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ torchvision
+ timm
+ rasterio
+ einops
+ huggingface_hub
+ gradio
+ git+https://github.com/IBM/terratorch.git
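
For a non-Docker setup, the same dependencies can presumably be installed directly with `pip3 install -r requirements.txt`. Note that terratorch is pulled from its git repository rather than a pinned release, so rebuilds may pick up newer code over time.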