pokkiri commited on
Commit
c74a613
·
verified ·
1 Parent(s): 3a50655

Update app.py

Browse files

fixed app module

Files changed (1) hide show
  1. app.py +162 -40
app.py CHANGED
@@ -5,7 +5,16 @@ Provides a web interface for making predictions with StableResNet
5
  Author: najahpokkiri
6
  Date: 2025-05-17
7
  """
 
 
 
 
 
 
 
 
8
  import os
 
9
  import torch
10
  import numpy as np
11
  import gradio as gr
@@ -15,33 +24,48 @@ import matplotlib.pyplot as plt
15
  import matplotlib.colors as colors
16
  from PIL import Image
17
  import io
 
18
  from huggingface_hub import hf_hub_download
19
 
 
 
 
 
20
  # Import model architecture
21
  from model import StableResNet
22
 
23
  class BiomassPredictorApp:
24
- """Gradio app for biomass prediction"""
25
 
26
  def __init__(self, model_repo="pokkiri/biomass-model"):
 
27
  self.model = None
28
  self.package = None
 
29
  self.model_repo = model_repo
30
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
 
 
 
32
  # Load the model
33
  self.load_model()
34
 
35
  def load_model(self):
36
  """Load the model and preprocessing pipeline from HuggingFace Hub"""
37
  try:
38
- # Download files from HuggingFace
 
 
39
  model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
40
  package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
41
 
42
  # Load package with metadata
43
  self.package = joblib.load(package_path)
 
 
44
  n_features = self.package['n_features']
 
45
 
46
  # Initialize model
47
  self.model = StableResNet(n_features=n_features)
@@ -49,17 +73,33 @@ class BiomassPredictorApp:
49
  self.model.to(self.device)
50
  self.model.eval()
51
 
52
- print(f"Model loaded successfully from {self.model_repo}")
53
- print(f"Number of features: {n_features}")
54
- print(f"Using device: {self.device}")
55
 
56
  return True
57
  except Exception as e:
58
- print(f"Error loading model: {e}")
 
 
59
  return False
60
 
 
 
 
 
 
 
 
 
 
 
 
61
  def predict_biomass(self, image_file, display_type="heatmap"):
62
  """Predict biomass from a satellite image"""
 
 
 
63
  try:
64
  # Create a temporary file to save the uploaded file
65
  with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
@@ -67,10 +107,14 @@ class BiomassPredictorApp:
67
  with open(image_file.name, 'rb') as f:
68
  tmp_file.write(f.read())
69
 
 
 
 
 
70
  try:
71
  import rasterio
72
  except ImportError:
73
- return None, "Error: rasterio is required but not installed."
74
 
75
  # Open the image file
76
  with rasterio.open(tmp_path) as src:
@@ -79,29 +123,39 @@ class BiomassPredictorApp:
79
  transform = src.transform
80
  crs = src.crs
81
 
82
- # Check if number of bands matches expected features
83
  if image.shape[0] < self.package['n_features']:
84
- return None, f"Error: Image has {image.shape[0]} bands, but model expects at least {self.package['n_features']} features."
 
85
 
86
- print(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
87
 
88
  # Process in chunks to avoid memory issues
89
- chunk_size = 1000
90
  predictions = np.zeros((height, width), dtype=np.float32)
91
 
92
  # Create mask for valid pixels (not NaN or Inf)
93
  valid_mask = np.all(np.isfinite(image), axis=0)
94
 
 
 
 
 
95
  # Process image in chunks
 
 
 
96
  for y_start in range(0, height, chunk_size):
97
  y_end = min(y_start + chunk_size, height)
98
 
99
  for x_start in range(0, width, chunk_size):
100
  x_end = min(x_start + chunk_size, width)
 
101
 
102
  # Get chunk mask
103
  chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
104
  if not np.any(chunk_mask):
 
105
  continue
106
 
107
  # Extract valid pixels
@@ -124,42 +178,58 @@ class BiomassPredictorApp:
124
  batch_predictions = self.model(batch_tensor).cpu().numpy()
125
 
126
  # Convert from log scale if needed
127
- if self.package['use_log_transform']:
128
- batch_predictions = np.exp(batch_predictions) - self.package.get('epsilon', 1.0)
 
129
  batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
130
 
131
  # Insert predictions back into the image
132
  for idx, (i, j) in enumerate(zip(valid_y, valid_x)):
133
  predictions[y_start+i, x_start+j] = batch_predictions[idx]
134
-
135
- # Delete temporary file
136
- os.unlink(tmp_path)
137
 
138
  # Create visualization
 
139
  plt.figure(figsize=(12, 8))
140
 
141
  if display_type == "heatmap":
142
  # Create heatmap
143
- plt.imshow(predictions, cmap='viridis')
 
 
 
 
 
 
 
144
  plt.colorbar(label='Biomass (Mg/ha)')
145
  plt.title('Predicted Above-Ground Biomass')
 
146
 
147
  elif display_type == "rgb_overlay":
148
  # Create RGB + overlay
149
  if image.shape[0] >= 3:
150
  # Use first 3 bands as RGB
151
  rgb = image[[0, 1, 2]].transpose(1, 2, 0)
152
- rgb = np.clip((rgb - np.percentile(rgb, 2)) / (np.percentile(rgb, 98) - np.percentile(rgb, 2)), 0, 1)
153
 
154
- plt.imshow(rgb)
 
 
 
 
 
 
155
 
156
  # Create mask for overlay (where we have predictions)
157
- mask = ~np.isclose(predictions, 0)
158
  overlay = np.zeros((height, width, 4))
159
 
160
  # Create colormap for biomass
161
- norm = colors.Normalize(vmin=np.percentile(predictions[mask], 5),
162
- vmax=np.percentile(predictions[mask], 95))
 
 
163
  cmap = plt.cm.viridis
164
 
165
  # Apply colormap
@@ -170,42 +240,63 @@ class BiomassPredictorApp:
170
  plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
171
  label='Biomass (Mg/ha)')
172
  plt.title('Biomass Prediction Overlay')
 
173
  else:
174
- plt.imshow(predictions, cmap='viridis')
 
 
175
  plt.colorbar(label='Biomass (Mg/ha)')
176
  plt.title('Predicted Above-Ground Biomass')
 
177
 
178
  # Save figure to bytes buffer
179
  buf = io.BytesIO()
180
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
181
  buf.seek(0)
 
182
 
183
- # Create summary statistics
184
  valid_predictions = predictions[valid_mask]
185
  stats = {
186
  'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
187
  'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
188
  'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
189
- 'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha",
190
- 'Total Biomass': f"{np.sum(valid_predictions) * (transform[0] * transform[0]) / 10000:.2f} Mg",
191
- 'Area': f"{np.sum(valid_mask) * (transform[0] * transform[0]) / 10000:.2f} hectares"
192
  }
193
 
 
 
 
 
 
 
 
 
 
194
  # Format statistics as markdown
195
  stats_md = "### Biomass Statistics\n\n"
196
  stats_md += "| Metric | Value |\n|--------|-------|\n"
197
  for k, v in stats.items():
198
  stats_md += f"| {k} | {v} |\n"
199
 
200
- # Close the plot
201
- plt.close()
 
 
 
202
 
203
  # Return visualization and statistics
204
  return Image.open(buf), stats_md
205
 
206
  except Exception as e:
 
 
 
207
  import traceback
208
- return None, f"Error predicting biomass: {str(e)}\n\n{traceback.format_exc()}"
 
 
 
209
 
210
  def create_interface(self):
211
  """Create Gradio interface"""
@@ -220,7 +311,7 @@ class BiomassPredictorApp:
220
  """)
221
 
222
  with gr.Row():
223
- with gr.Column():
224
  input_image = gr.File(
225
  label="Upload Satellite Image (GeoTIFF)",
226
  file_types=[".tif", ".tiff"]
@@ -232,9 +323,9 @@ class BiomassPredictorApp:
232
  label="Display Type"
233
  )
234
 
235
- submit_btn = gr.Button("Generate Biomass Prediction")
236
 
237
- with gr.Column():
238
  output_image = gr.Image(
239
  label="Biomass Prediction Map",
240
  type="pil"
@@ -245,7 +336,7 @@ class BiomassPredictorApp:
245
  )
246
 
247
  with gr.Accordion("About", open=False):
248
- gr.Markdown(f"""
249
  ## About This Model
250
 
251
  This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
@@ -255,9 +346,9 @@ class BiomassPredictorApp:
255
  - Architecture: StableResNet
256
  - Input: Multi-spectral satellite imagery
257
  - Output: Above-ground biomass (Mg/ha)
258
- - Creator: {pokkiri}
259
- - Date: {2025-05-17}
260
- - Model Repository: [{pokkiri/biomass-model}](https://huggingface.co/{pokkiri/biomass-model})
261
 
262
  ### How It Works
263
 
@@ -267,6 +358,27 @@ class BiomassPredictorApp:
267
  4. Results are visualized as a heatmap or RGB overlay
268
  """)
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  submit_btn.click(
271
  fn=self.predict_biomass,
272
  inputs=[input_image, display_type],
@@ -277,9 +389,19 @@ class BiomassPredictorApp:
277
 
278
  def launch_app():
279
  """Launch the Gradio app"""
280
- app = BiomassPredictorApp()
281
- interface = app.create_interface()
282
- interface.launch()
 
 
 
 
 
 
 
 
 
 
283
 
284
  if __name__ == "__main__":
285
  launch_app()
 
5
  Author: najahpokkiri
6
  Date: 2025-05-17
7
  """
8
+ """
9
+ Biomass Prediction Gradio App
10
+ Author: najahpokkiri
11
+ Date: 2025-05-17
12
+
13
+ This app allows users to predict above-ground biomass from satellite imagery
14
+ using a trained StableResNet model.
15
+ """
16
  import os
17
+ import sys
18
  import torch
19
  import numpy as np
20
  import gradio as gr
 
24
  import matplotlib.colors as colors
25
  from PIL import Image
26
  import io
27
+ import logging
28
  from huggingface_hub import hf_hub_download
29
 
30
+ # Configure logger
31
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
32
+ logger = logging.getLogger(__name__)
33
+
34
  # Import model architecture
35
  from model import StableResNet
36
 
37
  class BiomassPredictorApp:
38
+ """Gradio app for biomass prediction from satellite imagery"""
39
 
40
  def __init__(self, model_repo="pokkiri/biomass-model"):
41
+ """Initialize the app with model repository information"""
42
  self.model = None
43
  self.package = None
44
+ self.feature_names = []
45
  self.model_repo = model_repo
46
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47
 
48
+ # Cache for storing temporary files
49
+ self.temp_files = []
50
+
51
  # Load the model
52
  self.load_model()
53
 
54
  def load_model(self):
55
  """Load the model and preprocessing pipeline from HuggingFace Hub"""
56
  try:
57
+ logger.info(f"Loading model from {self.model_repo}")
58
+
59
+ # Download model files from HuggingFace
60
  model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
61
  package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
62
 
63
  # Load package with metadata
64
  self.package = joblib.load(package_path)
65
+
66
+ # Extract information from package
67
  n_features = self.package['n_features']
68
+ self.feature_names = self.package.get('feature_names', [f"feature_{i}" for i in range(n_features)])
69
 
70
  # Initialize model
71
  self.model = StableResNet(n_features=n_features)
 
73
  self.model.to(self.device)
74
  self.model.eval()
75
 
76
+ logger.info(f"Model loaded successfully from {self.model_repo}")
77
+ logger.info(f"Number of features: {n_features}")
78
+ logger.info(f"Using device: {self.device}")
79
 
80
  return True
81
  except Exception as e:
82
+ logger.error(f"Error loading model: {e}")
83
+ import traceback
84
+ logger.error(traceback.format_exc())
85
  return False
86
 
87
+ def cleanup(self):
88
+ """Clean up temporary files"""
89
+ for tmp_path in self.temp_files:
90
+ try:
91
+ if os.path.exists(tmp_path):
92
+ os.unlink(tmp_path)
93
+ except Exception as e:
94
+ logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
95
+
96
+ self.temp_files = []
97
+
98
  def predict_biomass(self, image_file, display_type="heatmap"):
99
  """Predict biomass from a satellite image"""
100
+ if self.model is None:
101
+ return None, "Error: Model not loaded. Please check logs for details."
102
+
103
  try:
104
  # Create a temporary file to save the uploaded file
105
  with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
 
107
  with open(image_file.name, 'rb') as f:
108
  tmp_file.write(f.read())
109
 
110
+ # Add to list for cleanup later
111
+ self.temp_files.append(tmp_path)
112
+
113
+ # Ensure rasterio is available
114
  try:
115
  import rasterio
116
  except ImportError:
117
+ return None, "Error: rasterio is required but not installed. Please install with: pip install rasterio"
118
 
119
  # Open the image file
120
  with rasterio.open(tmp_path) as src:
 
123
  transform = src.transform
124
  crs = src.crs
125
 
126
+ # Validate image dimensions
127
  if image.shape[0] < self.package['n_features']:
128
+ return None, (f"Error: Image has {image.shape[0]} bands, but model expects at least "
129
+ f"{self.package['n_features']} features.")
130
 
131
+ logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
132
 
133
  # Process in chunks to avoid memory issues
134
+ chunk_size = min(1000, height, width) # Adjust chunk size for smaller images
135
  predictions = np.zeros((height, width), dtype=np.float32)
136
 
137
  # Create mask for valid pixels (not NaN or Inf)
138
  valid_mask = np.all(np.isfinite(image), axis=0)
139
 
140
+ # Show progress indicator
141
+ progress_text = f"Processing {height}x{width} image..."
142
+ logger.info(progress_text)
143
+
144
  # Process image in chunks
145
+ total_chunks = ((height + chunk_size - 1) // chunk_size) * ((width + chunk_size - 1) // chunk_size)
146
+ chunk_count = 0
147
+
148
  for y_start in range(0, height, chunk_size):
149
  y_end = min(y_start + chunk_size, height)
150
 
151
  for x_start in range(0, width, chunk_size):
152
  x_end = min(x_start + chunk_size, width)
153
+ chunk_count += 1
154
 
155
  # Get chunk mask
156
  chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
157
  if not np.any(chunk_mask):
158
+ logger.info(f"Skipping chunk {chunk_count}/{total_chunks} (no valid pixels)")
159
  continue
160
 
161
  # Extract valid pixels
 
178
  batch_predictions = self.model(batch_tensor).cpu().numpy()
179
 
180
  # Convert from log scale if needed
181
+ if self.package.get('use_log_transform', False):
182
+ epsilon = self.package.get('epsilon', 1.0)
183
+ batch_predictions = np.exp(batch_predictions) - epsilon
184
  batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
185
 
186
  # Insert predictions back into the image
187
  for idx, (i, j) in enumerate(zip(valid_y, valid_x)):
188
  predictions[y_start+i, x_start+j] = batch_predictions[idx]
189
+
190
+ logger.info(f"Processed chunk {chunk_count}/{total_chunks}")
 
191
 
192
  # Create visualization
193
+ logger.info("Creating visualization...")
194
  plt.figure(figsize=(12, 8))
195
 
196
  if display_type == "heatmap":
197
  # Create heatmap
198
+ # Use masked array for better visualization
199
+ masked_predictions = np.ma.masked_where(~valid_mask, predictions)
200
+
201
+ # Set min/max values based on percentiles for better contrast
202
+ vmin = np.percentile(predictions[valid_mask], 1)
203
+ vmax = np.percentile(predictions[valid_mask], 99)
204
+
205
+ plt.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
206
  plt.colorbar(label='Biomass (Mg/ha)')
207
  plt.title('Predicted Above-Ground Biomass')
208
+ plt.axis('off') # Hide axes for cleaner visualization
209
 
210
  elif display_type == "rgb_overlay":
211
  # Create RGB + overlay
212
  if image.shape[0] >= 3:
213
  # Use first 3 bands as RGB
214
  rgb = image[[0, 1, 2]].transpose(1, 2, 0)
 
215
 
216
+ # Enhance contrast with percentile-based normalization
217
+ p2 = np.percentile(rgb[np.isfinite(rgb)], 2)
218
+ p98 = np.percentile(rgb[np.isfinite(rgb)], 98)
219
+ rgb_norm = np.clip((rgb - p2) / (p98 - p2), 0, 1)
220
+
221
+ # Display RGB image
222
+ plt.imshow(rgb_norm)
223
 
224
  # Create mask for overlay (where we have predictions)
225
+ mask = valid_mask & (~np.isclose(predictions, 0))
226
  overlay = np.zeros((height, width, 4))
227
 
228
  # Create colormap for biomass
229
+ norm = colors.Normalize(
230
+ vmin=np.percentile(predictions[mask], 5),
231
+ vmax=np.percentile(predictions[mask], 95)
232
+ )
233
  cmap = plt.cm.viridis
234
 
235
  # Apply colormap
 
240
  plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
241
  label='Biomass (Mg/ha)')
242
  plt.title('Biomass Prediction Overlay')
243
+ plt.axis('off')
244
  else:
245
+ # Fallback to regular heatmap if not enough bands for RGB
246
+ masked_predictions = np.ma.masked_where(~valid_mask, predictions)
247
+ plt.imshow(masked_predictions, cmap='viridis')
248
  plt.colorbar(label='Biomass (Mg/ha)')
249
  plt.title('Predicted Above-Ground Biomass')
250
+ plt.axis('off')
251
 
252
  # Save figure to bytes buffer
253
  buf = io.BytesIO()
254
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
255
  buf.seek(0)
256
+ plt.close()
257
 
258
+ # Calculate summary statistics
259
  valid_predictions = predictions[valid_mask]
260
  stats = {
261
  'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
262
  'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
263
  'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
264
+ 'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
 
 
265
  }
266
 
267
+ # Add area and total biomass if transform is available
268
+ if transform is not None:
269
+ pixel_area_m2 = abs(transform[0] * transform[4]) # Assuming square pixels
270
+ total_biomass = np.sum(valid_predictions) * (pixel_area_m2 / 10000) # Convert to hectares
271
+ area_hectares = np.sum(valid_mask) * (pixel_area_m2 / 10000)
272
+
273
+ stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
274
+ stats['Area'] = f"{area_hectares:.2f} hectares"
275
+
276
  # Format statistics as markdown
277
  stats_md = "### Biomass Statistics\n\n"
278
  stats_md += "| Metric | Value |\n|--------|-------|\n"
279
  for k, v in stats.items():
280
  stats_md += f"| {k} | {v} |\n"
281
 
282
+ # Add processing info
283
+ stats_md += f"\n\n*Processed {np.sum(valid_mask):,} valid pixels*"
284
+
285
+ # Cleanup temporary files
286
+ self.cleanup()
287
 
288
  # Return visualization and statistics
289
  return Image.open(buf), stats_md
290
 
291
  except Exception as e:
292
+ # Ensure cleanup even on error
293
+ self.cleanup()
294
+
295
  import traceback
296
+ logger.error(f"Error predicting biomass: {e}")
297
+ logger.error(traceback.format_exc())
298
+
299
+ return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
300
 
301
  def create_interface(self):
302
  """Create Gradio interface"""
 
311
  """)
312
 
313
  with gr.Row():
314
+ with gr.Column(scale=1):
315
  input_image = gr.File(
316
  label="Upload Satellite Image (GeoTIFF)",
317
  file_types=[".tif", ".tiff"]
 
323
  label="Display Type"
324
  )
325
 
326
+ submit_btn = gr.Button("Generate Biomass Prediction", variant="primary")
327
 
328
+ with gr.Column(scale=2):
329
  output_image = gr.Image(
330
  label="Biomass Prediction Map",
331
  type="pil"
 
336
  )
337
 
338
  with gr.Accordion("About", open=False):
339
+ gr.Markdown("""
340
  ## About This Model
341
 
342
  This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
 
346
  - Architecture: StableResNet
347
  - Input: Multi-spectral satellite imagery
348
  - Output: Above-ground biomass (Mg/ha)
349
+ - Creator: najahpokkiri
350
+ - Date: 2025-05-17
351
+ - Model Repository: [pokkiri/biomass-model](https://huggingface.co/pokkiri/biomass-model)
352
 
353
  ### How It Works
354
 
 
358
  4. Results are visualized as a heatmap or RGB overlay
359
  """)
360
 
361
+ with gr.Accordion("Examples", open=False):
362
+ gr.Markdown("""
363
+ ### Example Data
364
+
365
+ To try the model, you can use sample GeoTIFF files with the following characteristics:
366
+
367
+ - Multi-band satellite imagery (Sentinel-2, Landsat, etc.)
368
+ - Contains bands in the proper order (see documentation)
369
+ - Images should be relatively small (< 2000x2000 pixels) for faster processing
370
+
371
+ You can find sample data at:
372
+ - [Earth Explorer](https://earthexplorer.usgs.gov/)
373
+ - [Copernicus Open Access Hub](https://scihub.copernicus.eu/)
374
+ - [Planetary Computer](https://planetarycomputer.microsoft.com/)
375
+ """)
376
+
377
+ # Add a warning if model failed to load
378
+ if self.model is None:
379
+ gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
380
+
381
+ # Connect the submit button
382
  submit_btn.click(
383
  fn=self.predict_biomass,
384
  inputs=[input_image, display_type],
 
389
 
390
  def launch_app():
391
  """Launch the Gradio app"""
392
+ try:
393
+ # Create app instance
394
+ app = BiomassPredictorApp()
395
+
396
+ # Create interface
397
+ interface = app.create_interface()
398
+
399
+ # Launch interface
400
+ interface.launch(share=True)
401
+ except Exception as e:
402
+ logger.error(f"Error launching app: {e}")
403
+ import traceback
404
+ logger.error(traceback.format_exc())
405
 
406
  if __name__ == "__main__":
407
  launch_app()