Mars committed
Commit e99b724 · 1 Parent(s): 078b49f

Add application file

.gitattributes copy ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,62 @@
+ FROM python:3.8
+
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     # python3.8 \
+     # python3-pip \
+     # python3-setuptools \
+     git \
+     wget \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+
+ WORKDIR /code
+
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ # RUN conda install python=3.8
+
+ RUN pip install setuptools-rust
+ RUN pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 --extra-index-url https://download.pytorch.org/whl/cu115
+ RUN pip install gradio scikit-image pillow openmim
+ RUN pip install --upgrade setuptools
+
+ WORKDIR /home/user
+
+ RUN --mount=type=secret,id=git_token,mode=0444,required=true \
+     git clone --branch mmseg-only https://$(cat /run/secrets/git_token)@github.com/NASA-IMPACT/hls-foundation-os.git
+
+
+ WORKDIR hls-foundation-os
+
+ RUN git checkout 9968269915db8402bf4a6d0549df9df57d489e5a
+
+ RUN pip install -e .
+
+ RUN mim install mmcv-full==1.6.2 -f https://download.openmmlab.com/mmcv/dist/cu115/torch1.11.0/index.html
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/code/miniconda/lib"
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "app.py"]
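The Dockerfile pins torch 1.11.0 (cu115) via pip and mmcv-full 1.6.2 via mim against the NASA-IMPACT hls-foundation-os checkout. As an illustrative sanity check (not part of this commit; the script name and assertions are assumptions), one could verify inside the built container that the pinned versions are the ones actually importable:

```python
# sanity_check.py -- hypothetical helper, not included in the commit.
# Confirms the versions pinned in the Dockerfile resolved as intended.
import torch
import mmcv

assert torch.__version__.startswith("1.11.0"), f"unexpected torch version: {torch.__version__}"
assert mmcv.__version__ == "1.6.2", f"unexpected mmcv version: {mmcv.__version__}"
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("mmcv", mmcv.__version__)
```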
Legend.png ADDED
README copy.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Prithvi 100M Multi Temporal Crop Classification Demo
+ emoji: 📚
+ colorFrom: purple
+ colorTo: red
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
app.py ADDED
@@ -0,0 +1,303 @@
+ ######### pull files
+ import os
+ from huggingface_hub import hf_hub_download
+ #config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
+ #                            filename="multi_temporal_crop_classification_Prithvi_100M.py",
+ #                            token=os.environ.get("token"))
+
+ ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
+                      filename='multi_temporal_crop_classification_Prithvi_100M.pth',
+                      token=os.environ.get("token"))
+
+ config_path="multi_temporal_crop_classification_Prithvi_100M.py"
+ ##########
+ import argparse
+ from mmcv import Config
+
+ from mmseg.models import build_segmentor
+
+ from mmseg.datasets.pipelines import Compose, LoadImageFromFile
+
+ import rasterio
+ import torch
+
+ from mmseg.apis import init_segmentor
+
+ from mmcv.parallel import collate, scatter
+
+ import numpy as np
+ import glob
+ import os
+
+ import time
+
+ import numpy as np
+ import gradio as gr
+ from functools import partial
+
+ import pdb
+ import matplotlib.pyplot as plt
+
+ from skimage import exposure
+
+ import pandas as pd
+ from vega_datasets import data
+
+
+ cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
+                  {'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
+                  {'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
+                  {'value': 4, 'label': 'Soybeans', 'rgb': (38,115,0)},
+                  {'value': 5, 'label': 'Wetlands', 'rgb': (128,179,179)},
+                  {'value': 6, 'label': 'Developed/Barren', 'rgb': (156,156,156)},
+                  {'value': 7, 'label': 'Open Water', 'rgb': (77,112,163)},
+                  {'value': 8, 'label': 'Winter Wheat', 'rgb': (168,112,0)},
+                  {'value': 9, 'label': 'Alfalfa', 'rgb': (255,168,227)},
+                  {'value': 10, 'label': 'Fallow/Idle cropland', 'rgb': (191,191,122)},
+                  {'value': 11, 'label': 'Cotton', 'rgb':(255,38,38)},
+                  {'value': 12, 'label': 'Sorghum', 'rgb':(255,158,15)},
+                  {'value': 13, 'label': 'Other', 'rgb':(0,175,77)}]
+
+
+ def apply_color_map(rgb, color_map=cdl_color_map):
+
+
+     rgb_mapped = rgb.copy()
+
+     for map_tmp in cdl_color_map:
+
+         for i in range(3):
+             rgb_mapped[i] = np.where((rgb[0] == map_tmp['value']) & (rgb[1] == map_tmp['value']) & (rgb[2] == map_tmp['value']), map_tmp['rgb'][i], rgb_mapped[i])
+
+     return rgb_mapped
+
+
+ def stretch_rgb(rgb):
+
+     ls_pct=0
+     pLow, pHigh = np.percentile(rgb[~np.isnan(rgb)], (ls_pct,100-ls_pct))
+     img_rescale = exposure.rescale_intensity(rgb, in_range=(pLow,pHigh))
+
+     return img_rescale
+
+ def open_tiff(fname):
+
+     with rasterio.open(fname, "r") as src:
+
+         data = src.read()
+
+     return data
+
+ def write_tiff(img_wrt, filename, metadata):
+
+     """
+     It writes a raster image to file.
+
+     :param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands)
+     :param filename: file path to the output file
+     :param metadata: metadata to use to write the raster to disk
+     :return:
+     """
+
+     with rasterio.open(filename, "w", **metadata) as dest:
+
+         if len(img_wrt.shape) == 2:
+
+             img_wrt = img_wrt[None]
+
+         for i in range(img_wrt.shape[0]):
+             dest.write(img_wrt[i, :, :], i + 1)
+
+     return filename
+
+
+ def get_meta(fname):
+
+     with rasterio.open(fname, "r") as src:
+
+         meta = src.meta
+
+     return meta
+
+ def preprocess_example(example_list):
+
+     example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
+
+     return example_list
+
+
+ def inference_segmentor(model, imgs, custom_test_pipeline=None):
+     """Inference image(s) with the segmentor.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
+             images.
+
+     Returns:
+         (list[Tensor]): The segmentation result.
+     """
+     cfg = model.cfg
+     device = next(model.parameters()).device  # model device
+     # build the data pipeline
+     test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:] if custom_test_pipeline == None else custom_test_pipeline
+     test_pipeline = Compose(test_pipeline)
+     # prepare data
+     data = []
+     imgs = imgs if isinstance(imgs, list) else [imgs]
+     for img in imgs:
+         img_data = {'img_info': {'filename': img}}
+         img_data = test_pipeline(img_data)
+         data.append(img_data)
+     # print(data.shape)
+
+     data = collate(data, samples_per_gpu=len(imgs))
+     if next(model.parameters()).is_cuda:
+         # data = collate(data, samples_per_gpu=len(imgs))
+         # scatter to specified GPU
+         data = scatter(data, [device])[0]
+     else:
+         # img_metas = scatter(data['img_metas'],'cpu')
+         # data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+         img_metas = data['img_metas'].data[0]
+         img = data['img']
+         data = {'img': img, 'img_metas':img_metas}
+
+     with torch.no_grad():
+         result = model(return_loss=False, rescale=True, **data)
+     return result
+
+
+ def process_rgb(input, mask, indexes):
+
+
+     rgb = stretch_rgb((input[indexes, :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
+     rgb = np.where(mask.transpose((1,2,0)) == 1, 0, rgb)
+     rgb = np.where(rgb < 0, 0, rgb)
+     rgb = np.where(rgb > 255, 255, rgb)
+
+     return rgb
+
+ def inference_on_file(target_image, model, custom_test_pipeline):
+
+     target_image = target_image.name
+     time_taken=-1
+     st = time.time()
+     print('Running inference...')
+     try:
+         result = inference_segmentor(model, target_image, custom_test_pipeline)
+     except:
+         print('Error: Try different channels order.')
+         model.cfg.data.test.pipeline[0]['channels_last'] = True
+         result = inference_segmentor(model, target_image, custom_test_pipeline)
+     print("Output has shape: " + str(result[0].shape))
+
+     ##### get metadata mask
+     input = open_tiff(target_image)
+     meta = get_meta(target_image)
+     mask = np.where(input == meta['nodata'], 1, 0)
+     mask = np.max(mask, axis=0)[None]
+
+     rgb1 = process_rgb(input, mask, [2, 1, 0])
+     rgb2 = process_rgb(input, mask, [8, 7, 6])
+     rgb3 = process_rgb(input, mask, [14, 13, 12])
+
+     result[0] = np.where(mask == 1, 0, result[0])
+
+     et = time.time()
+     time_taken = np.round(et - st, 1)
+     print(f'Inference completed in {str(time_taken)} seconds')
+
+     output=result[0][0] + 1
+     output = np.vstack([output[None], output[None], output[None]]).astype(np.uint8)
+
+
+     output=apply_color_map(output).transpose((1,2,0))
+
+     return rgb1,rgb2,rgb3,output
+
+ def process_test_pipeline(custom_test_pipeline, bands=None):
+
+     # change extracted bands if necessary
+     if bands is not None:
+
+         extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract' ]
+
+         if len(extract_index) > 0:
+
+             custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
+
+     collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
+
+     # adapt collected keys if necessary
+     if len(collect_index) > 0:
+
+         keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
+         custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
+
+     return custom_test_pipeline
+
+ config = Config.fromfile(config_path)
+ config.model.backbone.pretrained=None
+ model = init_segmentor(config, ckpt, device='cpu')
+ custom_test_pipeline=process_test_pipeline(model.cfg.data.test.pipeline, None)
+
+ func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
+
+
+ stocks = data.stocks()
+ gapminder = data.gapminder()
+ gapminder = gapminder.loc[
+     gapminder.country.isin(["Argentina", "Australia", "Afghanistan"])
+ ]
+ climate = data.climate()
+ seattle_weather = data.seattle_weather()
+
+ simple = pd.DataFrame(
+     {
+         "a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
+         "b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
+     }
+ )
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown(value='# Prithvi multi temporal crop classification')
+     gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
+     The user needs to provide an HLS geotiff image including 18 bands for 3 time-steps, where each time-step contains the six channels (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in that order.
+     ''')
+     with gr.Row():
+         with gr.Column():
+             inp = gr.File()
+             btn = gr.Button("Submit")
+
+         with gr.Column():
+             inp1=gr.Image(image_mode='RGB', scale=10, label='T1')
+             inp2=gr.Image(image_mode='RGB', scale=10, label='T2')
+             inp3=gr.Image(image_mode='RGB', scale=10, label='T3')
+             out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')
+             # gr.Image(value='Legend.png', image_mode='RGB', scale=2, show_label=False)
+
+     btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 gr.BarPlot(simple,
+                            x="a",
+                            y="b",
+                            title="Simple Bar Plot with made up data",
+                            tooltip=["a", "b"],
+                            y_lim=[20, 100],)
+             with gr.Row():
+                 gr.LinePlot(simple,
+                             x='a',
+                             y='b')
+
+         with gr.Column():
+             gr.Markdown(value='### Model prediction legend')
+             gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
+
+
+ demo.launch()
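As the app's intro text explains, the demo expects a single HLS GeoTIFF with 18 bands (3 time-steps of 6 channels each); `process_rgb` builds the T1/T2/T3 previews from band indexes [2, 1, 0], [8, 7, 6] and [14, 13, 12]. A small pre-upload check along these lines may help catch malformed inputs; this is a hypothetical sketch, not part of app.py, and the file name in the usage comment is an assumption:

```python
# validate_input.py -- hypothetical sketch for checking a candidate upload.
import numpy as np
import rasterio

def check_hls_geotiff(path, expected_bands=18):
    """Print basic stats and fail loudly if the band count is wrong."""
    with rasterio.open(path) as src:
        data = src.read()          # shape: (bands, height, width)
        nodata = src.nodata
    assert data.shape[0] == expected_bands, f"expected {expected_bands} bands, got {data.shape[0]}"
    valid = (data != nodata) if nodata is not None else np.ones(data.shape, dtype=bool)
    print("shape:", data.shape, "| nodata:", nodata, "| valid fraction:", round(float(valid.mean()), 3))
    return data

# check_hls_geotiff("chip_merged.tif")  # example usage with an assumed file name
```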
multi_temporal_crop_classification_Prithvi_100M.py ADDED
@@ -0,0 +1,234 @@
+ import os
+
+ dist_params = dict(backend='nccl')
+ log_level = 'INFO'
+ load_from = None
+ resume_from = None
+ cudnn_benchmark = True
+ custom_imports = dict(imports=['geospatial_fm'])
+ num_frames = 3
+ img_size = 224
+ num_workers = 2
+
+ # model
+ # TO BE DEFINED BY USER: model path
+ pretrained_weights_path = '<path to pretrained weights>'
+ num_layers = 6
+ patch_size = 16
+ embed_dim = 768
+ num_heads = 8
+ tubelet_size = 1
+ max_epochs = 80
+ eval_epoch_interval = 5
+
+ loss_weights_multi = [
+     0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
+     1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
+ ]
+ loss_func = dict(
+     type='CrossEntropyLoss',
+     use_sigmoid=False,
+     class_weight=loss_weights_multi,
+     avg_non_ignore=True)
+ output_embed_dim = embed_dim*num_frames
+
+
+ # TO BE DEFINED BY USER: Save directory
+ experiment = '<experiment name>'
+ project_dir = '<project directory name>'
+ work_dir = os.path.join(project_dir, experiment)
+ save_path = work_dir
+
+
+ gpu_ids = range(0, 1)
+ dataset_type = 'GeospatialDataset'
+
+ # TO BE DEFINED BY USER: data directory
+ data_root = '<path to data root>'
+
+ splits = dict(
+     train='<path to train split>',
+     val='<path to val split>',
+     test='<path to test split>'
+ )
+
+
+ img_norm_cfg = dict(
+     means=[
+         494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
+         1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
+         2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
+         2968.881459, 2634.621962, 1739.579917
+     ],
+     stds=[
+         284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
+         284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
+         284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
+     ])
+
+ bands = [0, 1, 2, 3, 4, 5]
+
+ tile_size = 224
+ orig_nsize = 512
+ crop_size = (tile_size, tile_size)
+ train_pipeline = [
+     dict(type='LoadGeospatialImageFromFile', to_float32=True),
+     dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
+     dict(type='RandomFlip', prob=0.5),
+     dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
+     # to channels first
+     dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
+     dict(type='TorchNormalize', **img_norm_cfg),
+     dict(type='TorchRandomCrop', crop_size=crop_size),
+     dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
+     dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
+     dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
+     dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+ ]
+
+ test_pipeline = [
+     dict(type='LoadGeospatialImageFromFile', to_float32=True),
+     dict(type='ToTensor', keys=['img']),
+     # to channels first
+     dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
+     dict(type='TorchNormalize', **img_norm_cfg),
+     dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
+     dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
+     dict(type='CollectTestList', keys=['img'],
+          meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
+                     'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
+ ]
+
+ CLASSES = ('Natural Vegetation',
+            'Forest',
+            'Corn',
+            'Soybeans',
+            'Wetlands',
+            'Developed/Barren',
+            'Open Water',
+            'Winter Wheat',
+            'Alfalfa',
+            'Fallow/Idle Cropland',
+            'Cotton',
+            'Sorghum',
+            'Other')
+
+ dataset = 'GeospatialDataset'
+
+ data = dict(
+     samples_per_gpu=8,
+     workers_per_gpu=4,
+     train=dict(
+         type=dataset,
+         CLASSES=CLASSES,
+         reduce_zero_label=True,
+         data_root=data_root,
+         img_dir='training_chips',
+         ann_dir='training_chips',
+         pipeline=train_pipeline,
+         img_suffix='_merged.tif',
+         seg_map_suffix='.mask.tif',
+         split=splits['train']),
+     val=dict(
+         type=dataset,
+         CLASSES=CLASSES,
+         reduce_zero_label=True,
+         data_root=data_root,
+         img_dir='validation_chips',
+         ann_dir='validation_chips',
+         pipeline=test_pipeline,
+         img_suffix='_merged.tif',
+         seg_map_suffix='.mask.tif',
+         split=splits['val']
+     ),
+     test=dict(
+         type=dataset,
+         CLASSES=CLASSES,
+         reduce_zero_label=True,
+         data_root=data_root,
+         img_dir='validation_chips',
+         ann_dir='validation_chips',
+         pipeline=test_pipeline,
+         img_suffix='_merged.tif',
+         seg_map_suffix='.mask.tif',
+         split=splits['val']
+     ))
+
+ optimizer = dict(
+     type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
+ optimizer_config = dict(grad_clip=None)
+ lr_config = dict(
+     policy='poly',
+     warmup='linear',
+     warmup_iters=1500,
+     warmup_ratio=1e-06,
+     power=1.0,
+     min_lr=0.0,
+     by_epoch=False)
+ log_config = dict(
+     interval=10,
+     hooks=[dict(type='TextLoggerHook'),
+            dict(type='TensorboardLoggerHook')])
+
+ checkpoint_config = dict(
+     by_epoch=True,
+     interval=100,
+     out_dir=save_path)
+
+ evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
+ reduce_train_set = dict(reduce_train_set=False)
+ reduce_factor = dict(reduce_factor=1)
+ runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
+ workflow = [('train', 1)]
+ norm_cfg = dict(type='BN', requires_grad=True)
+
+ model = dict(
+     type='TemporalEncoderDecoder',
+     frozen_backbone=False,
+     backbone=dict(
+         type='TemporalViTEncoder',
+         pretrained=pretrained_weights_path,
+         img_size=img_size,
+         patch_size=patch_size,
+         num_frames=num_frames,
+         tubelet_size=1,
+         in_chans=len(bands),
+         embed_dim=embed_dim,
+         depth=6,
+         num_heads=num_heads,
+         mlp_ratio=4.0,
+         norm_pix_loss=False),
+     neck=dict(
+         type='ConvTransformerTokensToEmbeddingNeck',
+         embed_dim=embed_dim*num_frames,
+         output_embed_dim=output_embed_dim,
+         drop_cls_token=True,
+         Hp=14,
+         Wp=14),
+     decode_head=dict(
+         num_classes=len(CLASSES),
+         in_channels=output_embed_dim,
+         type='FCNHead',
+         in_index=-1,
+         channels=256,
+         num_convs=1,
+         concat_input=False,
+         dropout_ratio=0.1,
+         norm_cfg=dict(type='BN', requires_grad=True),
+         align_corners=False,
+         loss_decode=loss_func),
+     auxiliary_head=dict(
+         num_classes=len(CLASSES),
+         in_channels=output_embed_dim,
+         type='FCNHead',
+         in_index=-1,
+         channels=256,
+         num_convs=2,
+         concat_input=False,
+         dropout_ratio=0.1,
+         norm_cfg=dict(type='BN', requires_grad=True),
+         align_corners=False,
+         loss_decode=loss_func),
+     train_cfg=dict(),
+     test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
+ auto_resume = False
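This config leaves the training placeholders ('<path to pretrained weights>', data_root, the split files) to be filled in by the user; app.py instead loads it with mmcv's Config, clears `model.backbone.pretrained`, and initializes the model from the downloaded finetuned checkpoint. A minimal sketch of that reuse pattern, assuming MMSegmentation 0.x with the hls-foundation-os extensions installed and the checkpoint already downloaded to the working directory:

```python
# Hypothetical sketch: standalone CPU inference setup reusing this config,
# mirroring the pattern in app.py (the local checkpoint path is an assumption).
from mmcv import Config
from mmseg.apis import init_segmentor

cfg = Config.fromfile("multi_temporal_crop_classification_Prithvi_100M.py")
cfg.model.backbone.pretrained = None  # weights come from the finetuned checkpoint, not the placeholder path

model = init_segmentor(cfg, "multi_temporal_crop_classification_Prithvi_100M.pth", device="cpu")
print(len(cfg.CLASSES), "classes:", cfg.CLASSES[:3], "...")
```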
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.7.1
+ torchvision==0.8.2
+ openmim