Files changed (1)
  1. app.py +0 -270
app.py DELETED
@@ -1,270 +0,0 @@
- ######### Pull model config and checkpoint from the Hugging Face Hub
- import os
- from huggingface_hub import hf_hub_download
- config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
-                               filename="multi_temporal_crop_classification_Prithvi_100M.py",
-                               token=os.environ.get("token"))
- ckpt = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
-                        filename="multi_temporal_crop_classification_Prithvi_100M.pth",
-                        token=os.environ.get("token"))
- ##########
-
- import time
- from functools import partial
-
- import numpy as np
- import rasterio
- import torch
- import gradio as gr
- from skimage import exposure
-
- from mmcv import Config
- from mmcv.parallel import collate, scatter
- from mmseg.apis import init_segmentor
- from mmseg.datasets.pipelines import Compose, LoadImageFromFile
-
- cdl_color_map = [{'value': 1,  'label': 'Natural vegetation',   'rgb': (233, 255, 190)},
-                  {'value': 2,  'label': 'Forest',               'rgb': (149, 206, 147)},
-                  {'value': 3,  'label': 'Corn',                 'rgb': (255, 212, 0)},
-                  {'value': 4,  'label': 'Soybeans',             'rgb': (38, 115, 0)},
-                  {'value': 5,  'label': 'Wetlands',             'rgb': (128, 179, 179)},
-                  {'value': 6,  'label': 'Developed/Barren',     'rgb': (156, 156, 156)},
-                  {'value': 7,  'label': 'Open Water',           'rgb': (77, 112, 163)},
-                  {'value': 8,  'label': 'Winter Wheat',         'rgb': (168, 112, 0)},
-                  {'value': 9,  'label': 'Alfalfa',              'rgb': (255, 168, 227)},
-                  {'value': 10, 'label': 'Fallow/Idle cropland', 'rgb': (191, 191, 122)},
-                  {'value': 11, 'label': 'Cotton',               'rgb': (255, 38, 38)},
-                  {'value': 12, 'label': 'Sorghum',              'rgb': (255, 158, 15)},
-                  {'value': 13, 'label': 'Other',                'rgb': (0, 175, 77)}]
-
-
- def apply_color_map(rgb, color_map=cdl_color_map):
-     """Replace class ids in a 3-band array with the legend RGB colors."""
-     rgb_mapped = rgb.copy()
-     for map_tmp in color_map:
-         for i in range(3):
-             rgb_mapped[i] = np.where((rgb[0] == map_tmp['value']) & (rgb[1] == map_tmp['value']) & (rgb[2] == map_tmp['value']),
-                                      map_tmp['rgb'][i], rgb_mapped[i])
-     return rgb_mapped
-
-
- def stretch_rgb(rgb):
-     """Linearly stretch an RGB array between its low and high percentiles."""
-     ls_pct = 0  # with 0, the stretch spans the full 0th-100th percentile range
-     pLow, pHigh = np.percentile(rgb[~np.isnan(rgb)], (ls_pct, 100 - ls_pct))
-     img_rescale = exposure.rescale_intensity(rgb, in_range=(pLow, pHigh))
-     return img_rescale
-
-
- def open_tiff(fname):
-     with rasterio.open(fname, "r") as src:
-         data = src.read()
-     return data
-
-
- def write_tiff(img_wrt, filename, metadata):
-     """Write a raster image to file.
-
-     :param img_wrt: numpy array containing the data (2D for a single band or 3D for multiple bands)
-     :param filename: file path to the output file
-     :param metadata: metadata to use to write the raster to disk
-     :return: the output file path
-     """
-     with rasterio.open(filename, "w", **metadata) as dest:
-         if len(img_wrt.shape) == 2:
-             img_wrt = img_wrt[None]
-         for i in range(img_wrt.shape[0]):
-             dest.write(img_wrt[i, :, :], i + 1)
-     return filename
-
-
- def get_meta(fname):
-     with rasterio.open(fname, "r") as src:
-         meta = src.meta
-     return meta
-
-
- def preprocess_example(example_list):
-     # resolve the example file names to absolute paths
-     example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
-     return example_list
-
-
- def inference_segmentor(model, imgs, custom_test_pipeline=None):
-     """Inference image(s) with the segmentor.
-
-     Args:
-         model (nn.Module): The loaded segmentor.
-         imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
-             images.
-
-     Returns:
-         (list[Tensor]): The segmentation result.
-     """
-     cfg = model.cfg
-     device = next(model.parameters()).device  # model device
-     # build the data pipeline
-     test_pipeline = [LoadImageFromFile()] + cfg.data.test.pipeline[1:] if custom_test_pipeline is None else custom_test_pipeline
-     test_pipeline = Compose(test_pipeline)
-     # prepare data
-     data = []
-     imgs = imgs if isinstance(imgs, list) else [imgs]
-     for img in imgs:
-         img_data = {'img_info': {'filename': img}}
-         img_data = test_pipeline(img_data)
-         data.append(img_data)
-     data = collate(data, samples_per_gpu=len(imgs))
-     if next(model.parameters()).is_cuda:
-         # scatter the batch to the model's GPU
-         data = scatter(data, [device])[0]
-     else:
-         # on CPU, unwrap the DataContainers produced by collate
-         img_metas = data['img_metas'].data[0]
-         img = data['img']
-         data = {'img': img, 'img_metas': img_metas}
-
-     with torch.no_grad():
-         result = model(return_loss=False, rescale=True, **data)
-     return result
-
-
- def process_rgb(input, mask, indexes):
-     # scale reflectance (0-10000) to 8-bit, stretch, zero out nodata pixels, and clip to [0, 255]
-     rgb = stretch_rgb((input[indexes, :, :].transpose((1, 2, 0)) / 10000 * 255).astype(np.uint8))
-     rgb = np.where(mask.transpose((1, 2, 0)) == 1, 0, rgb)
-     rgb = np.where(rgb < 0, 0, rgb)
-     rgb = np.where(rgb > 255, 255, rgb)
-     return rgb
-
-
- def inference_on_file(target_image, model, custom_test_pipeline):
-     target_image = target_image.name
-     st = time.time()
-     print('Running inference...')
-     result = inference_segmentor(model, target_image, custom_test_pipeline)
-     print("Output has shape: " + str(result[0].shape))
-
-     ##### build a nodata mask from the image metadata
-     input = open_tiff(target_image)
-     meta = get_meta(target_image)
-     mask = np.where(input == meta['nodata'], 1, 0)
-     mask = np.max(mask, axis=0)[None]
-
-     # true-color composites for the three time-steps (Red/Green/Blue of each step)
-     rgb1 = process_rgb(input, mask, [2, 1, 0])
-     rgb2 = process_rgb(input, mask, [8, 7, 6])
-     rgb3 = process_rgb(input, mask, [14, 13, 12])
-
-     result[0] = np.where(mask == 1, 0, result[0])
-
-     et = time.time()
-     time_taken = np.round(et - st, 1)
-     print(f'Inference completed in {time_taken} seconds')
-
-     # shift class ids to match the legend values, replicate to 3 bands, and colorize
-     output = result[0][0] + 1
-     output = np.vstack([output[None], output[None], output[None]]).astype(np.uint8)
-     output = apply_color_map(output).transpose((1, 2, 0))
-
-     return rgb1, rgb2, rgb3, output
-
-
- def process_test_pipeline(custom_test_pipeline, bands=None):
-     # change extracted bands if necessary
-     if bands is not None:
-         extract_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'] == 'BandsExtract']
-         if len(extract_index) > 0:
-             # `bands` arrives as a string such as "[0, 1, 2]"
-             custom_test_pipeline[extract_index[0]]['bands'] = eval(bands)
-
-     # adapt collected keys if necessary
-     collect_index = [i for i, x in enumerate(custom_test_pipeline) if x['type'].find('Collect') > -1]
-     if len(collect_index) > 0:
-         keys = ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']
-         custom_test_pipeline[collect_index[0]]['meta_keys'] = keys
-
-     return custom_test_pipeline
-
-
- # build the model from the downloaded config and checkpoint (CPU inference)
- config = Config.fromfile(config_path)
- config.model.backbone.pretrained = None
- model = init_segmentor(config, ckpt, device='cpu')
- custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, None)
-
- func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
-
- with gr.Blocks() as demo:
-
-     gr.Markdown(value='# Prithvi multi-temporal crop classification')
-     gr.Markdown(value='''Prithvi is a first-of-its-kind temporal vision transformer pretrained by the IBM and NASA team on continental US Harmonized Landsat Sentinel-2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land-use categories using multi-temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
-     The user needs to provide an HLS GeoTIFF image with 18 bands covering 3 time-steps, where each time-step contains the channels (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in that order.
-     ''')
-     with gr.Row():
-         with gr.Column():
-             inp = gr.File()
-             btn = gr.Button("Submit")
-
-     with gr.Row():
-         inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
-         inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
-         inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
-         out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')
-
-     btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
-
-     with gr.Row():
-         with gr.Column():
-             gr.Examples(examples=["chip_102_345_merged.tif",
-                                   "chip_104_104_merged.tif",
-                                   "chip_109_421_merged.tif"],
-                         inputs=inp,
-                         outputs=[inp1, inp2, inp3, out],
-                         preprocess=preprocess_example,
-                         fn=func,
-                         cache_examples=True)
-         with gr.Column():
-             gr.Markdown(value='### Model prediction legend')
-             gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
-
- demo.launch()
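
The deleted app expected a single 18-band GeoTIFF as input (3 time-steps of 6 HLS channels each). As a rough illustration of that layout, here is a minimal sketch, not part of the app, of how three co-registered 6-band HLS chips could be stacked into one such file with rasterio; the file names are hypothetical.

```python
# Hypothetical sketch: stack three 6-band HLS chips into the 18-band
# GeoTIFF layout the demo expects (time-steps concatenated in order).
import numpy as np
import rasterio

time_step_files = ["hls_t1.tif", "hls_t2.tif", "hls_t3.tif"]  # hypothetical inputs

arrays, meta = [], None
for fname in time_step_files:
    with rasterio.open(fname) as src:
        arrays.append(src.read())  # (6, H, W): Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2
        meta = src.meta            # reuse the georeferencing of the last chip

stacked = np.concatenate(arrays, axis=0)  # (18, H, W)
meta.update(count=stacked.shape[0])
with rasterio.open("merged_chip.tif", "w", **meta) as dst:
    dst.write(stacked)
```

The chips must share the same CRS, transform, and shape for this concatenation to be meaningful; the band indexes `[2, 1, 0]`, `[8, 7, 6]`, and `[14, 13, 12]` used by `inference_on_file` then pick out the Red/Green/Blue channels of each time-step.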