Paolo-Fraccaro commited on
Commit
d54b8be
·
1 Parent(s): 8dd8672

fix no data issue

Browse files
Files changed (1) hide show
  1. app.py +8 -13
app.py CHANGED
@@ -137,28 +137,23 @@ def inference_on_file(target_image, model, custom_test_pipeline):
137
  result = inference_segmentor(model, target_image, custom_test_pipeline)
138
  print("Output has shape: " + str(result[0].shape))
139
 
140
- ##### get metadata mask
141
  mask = open_tiff(target_image)
142
- # rgb = mask[[2, 1, 0], :, :].transpose((1,2,0))
143
  rgb = mask[[5, 3, 2], :, :].transpose((1,2,0))
144
  meta = get_meta(target_image)
145
  mask = np.where(mask == meta['nodata'], 1, 0)
146
  mask = np.max(mask, axis=0)[None]
147
-
148
- result[0] = np.where(mask == 1, -1, result[0])
149
-
150
- ##### Save file to disk
151
- meta["count"] = 1
152
- meta["dtype"] = "int16"
153
- meta["compress"] = "lzw"
154
- meta["nodata"] = -1
155
- print('Saving output...')
156
- # write_tiff(result[0], output_image, meta)
157
  et = time.time()
158
  time_taken = np.round(et - st, 1)
159
  print(f'Inference completed in {str(time_taken)} seconds')
160
 
161
- return rgb, result[0][0]*255
 
162
 
163
  def process_test_pipeline(custom_test_pipeline, bands=None):
164
 
 
137
  result = inference_segmentor(model, target_image, custom_test_pipeline)
138
  print("Output has shape: " + str(result[0].shape))
139
 
140
+ # prep outputs
141
  mask = open_tiff(target_image)
 
142
  rgb = mask[[5, 3, 2], :, :].transpose((1,2,0))
143
  meta = get_meta(target_image)
144
  mask = np.where(mask == meta['nodata'], 1, 0)
145
  mask = np.max(mask, axis=0)[None]
146
+ rgb = np.where(mask.transpose((1,2,0)) == 1, 0, rgb)
147
+ rgb = np.where(rgb < 0, 0, rgb)
148
+ rgb = np.where(rgb > 1, 1, rgb)
149
+
150
+ prediction = np.where(mask == 1, 0, result[0]*255)
 
 
 
 
 
151
  et = time.time()
152
  time_taken = np.round(et - st, 1)
153
  print(f'Inference completed in {str(time_taken)} seconds')
154
 
155
+ return rgb, prediction[0]
156
+
157
 
158
  def process_test_pipeline(custom_test_pipeline, bands=None):
159