Paolo-Fraccaro commited on
Commit
2cfb5a4
1 Parent(s): 9db62e4

fix no data and channels different order

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -163,37 +163,42 @@ def inference_segmentor(model, imgs, custom_test_pipeline=None):
163
  return result
164
 
165
 
 
 
 
 
 
 
 
 
 
 
166
  def inference_on_file(target_image, model, custom_test_pipeline):
167
 
168
  target_image = target_image.name
169
- # print(type(target_image))
170
-
171
- # output_image = target_image.replace('.tif', '_pred.tif')
172
  time_taken=-1
173
  st = time.time()
174
  print('Running inference...')
175
- result = inference_segmentor(model, target_image, custom_test_pipeline)
 
 
 
 
 
176
  print("Output has shape: " + str(result[0].shape))
177
 
178
  ##### get metadata mask
179
- mask = open_tiff(target_image)
180
- # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
181
- rgb1 = stretch_rgb((mask[[2, 1, 0], :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
182
- rgb2 = stretch_rgb((mask[[8, 7, 6], :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
183
- rgb3 = stretch_rgb((mask[[14, 13, 12], :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
184
  meta = get_meta(target_image)
185
- mask = np.where(mask == meta['nodata'], 1, 0)
186
  mask = np.max(mask, axis=0)[None]
 
 
 
 
187
 
188
- result[0] = np.where(mask == 1, -1, result[0])
189
 
190
- ##### Save file to disk
191
- meta["count"] = 1
192
- meta["dtype"] = "int16"
193
- meta["compress"] = "lzw"
194
- meta["nodata"] = -1
195
- print('Saving output...')
196
- # write_tiff(result[0], output_image, meta)
197
  et = time.time()
198
  time_taken = np.round(et - st, 1)
199
  print(f'Inference completed in {str(time_taken)} seconds')
 
163
  return result
164
 
165
 
166
+ def process_rgb(input, mask, indexes):
167
+
168
+
169
+ rgb = stretch_rgb((input[indexes, :, :].transpose((1,2,0))/10000*255).astype(np.uint8))
170
+ rgb = np.where(mask.transpose((1,2,0)) == 1, 0, rgb)
171
+ rgb = np.where(rgb < 0, 0, rgb)
172
+ rgb = np.where(rgb > 255, 255, rgb)
173
+
174
+ return rgb
175
+
176
  def inference_on_file(target_image, model, custom_test_pipeline):
177
 
178
  target_image = target_image.name
 
 
 
179
  time_taken=-1
180
  st = time.time()
181
  print('Running inference...')
182
+ try:
183
+ result = inference_segmentor(model, target_image, custom_test_pipeline)
184
+ except:
185
+ print('Error: Try different channels order.')
186
+ model.cfg.data.test.pipeline[0]['channels_last'] = True
187
+ result = inference_segmentor(model, target_image, custom_test_pipeline)
188
  print("Output has shape: " + str(result[0].shape))
189
 
190
  ##### get metadata mask
191
+ input = open_tiff(target_image)
 
 
 
 
192
  meta = get_meta(target_image)
193
+ mask = np.where(input == meta['nodata'], 1, 0)
194
  mask = np.max(mask, axis=0)[None]
195
+
196
+ rgb1 = process_rgb(input, mask, [2, 1, 0])
197
+ rgb2 = process_rgb(input, mask, [8, 7, 6])
198
+ rgb3 = process_rgb(input, mask, [14, 13, 12])
199
 
200
+ result[0] = np.where(mask == 1, 0, result[0])
201
 
 
 
 
 
 
 
 
202
  et = time.time()
203
  time_taken = np.round(et - st, 1)
204
  print(f'Inference completed in {str(time_taken)} seconds')