pokkiri committed on
Commit
a49479d
·
verified ·
1 Parent(s): 0d67e1c

Update feature_engineering.py

Browse files
Files changed (1) hide show
  1. feature_engineering.py +377 -670
feature_engineering.py CHANGED
@@ -1,709 +1,416 @@
1
- def create_interface(self):
2
- """Create Gradio interface with sample image thumbnails"""
3
- # Generate thumbnails for sample images
4
- sample_thumbnails = {}
5
- for name, path in self.sample_images.items():
6
- if os.path.exists(path):
7
- thumbnail = self.create_thumbnail(path)
8
- if thumbnail:
9
- sample_thumbnails[name] = Image.open(thumbnail)
10
- else:
11
- logger.warning(f"Sample image not found: {path}")
12
-
13
- with gr.Blocks(title="Biomass Prediction Model") as interface:
14
- gr.Markdown("# Above-Ground Biomass Prediction")
15
- gr.Markdown("""
16
- Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
17
-
18
- **Requirements:**
19
- - Image must be a GeoTIFF with spectral bands
20
- - For best results, image should contain at least 3 bands
21
- """)
22
-
23
- with gr.Row():
24
- with gr.Column(scale=1):
25
- input_image = gr.File(
26
- label="Upload Satellite Image (GeoTIFF)",
27
- file_types=[".tif", ".tiff"]
28
- )
29
-
30
- # Sample images section
31
- gr.Markdown("### Sample Images")
32
-
33
- # Sample buttons container
34
- sample_buttons = []
35
-
36
- # First row - sample thumbnails side by side horizontally
37
- with gr.Row():
38
- for name, thumbnail in sample_thumbnails.items():
39
- with gr.Column():
40
- gr.Image(
41
- value=thumbnail,
42
- label=name.replace("input_", "Input ").replace("chip_", "Chip "),
43
- show_download_button=False,
44
- height=180
45
- )
46
-
47
- # Second row - buttons side by side horizontally, matching the thumbnails above
48
- with gr.Row():
49
- for name, _ in sample_thumbnails.items():
50
- with gr.Column():
51
- sample_btn = gr.Button(
52
- f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
53
- variant="secondary",
54
- size="lg"
55
- )
56
- sample_buttons.append((sample_btn, name))
57
-
58
- # Generate button at the bottom
59
- generate_btn = gr.Button("Generate Biomass Prediction", variant="primary", size="lg")
60
-
61
- with gr.Column(scale=2):
62
- output_image = gr.Image(
63
- label="Biomass Prediction Map",
64
- type="pil"
65
- )
66
-
67
- output_stats = gr.Markdown(
68
- label="Statistics"
69
- )
70
-
71
- with gr.Accordion("About", open=False):
72
- gr.Markdown("""
73
- ## About This Model
74
-
75
- This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
76
-
77
- ### Model Details
78
-
79
- - Architecture: StableResNet
80
- - Input: Multi-spectral satellite imagery
81
- - Output: Above-ground biomass (Mg/ha)
82
- - Creator: vertify.earth
83
- - Date: 2025-05-19
84
-
85
- ### Improvements in This Version
86
-
87
- - Added calibration factor to match full-tile inference values
88
- - Improved chunk processing with overlap to reduce edge artifacts
89
- - Enhanced feature calculation for better results
90
- - Optimized visualization to show the full range of biomass values
91
- """)
92
-
93
- # Add a warning if model failed to load
94
- if self.model is None:
95
- gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
96
-
97
- # Connect the process button
98
- generate_btn.click(
99
- fn=self.predict_biomass,
100
- inputs=[input_image],
101
- outputs=[output_image, output_stats]
102
- )
103
-
104
- # Connect the sample buttons
105
- for button, name in sample_buttons:
106
- button.click(
107
- fn=lambda path=self.sample_images[name]: self.predict_biomass(path),
108
- inputs=[],
109
- outputs=[output_image, output_stats]
110
- )
111
-
112
- return interface
113
-
114
- def launch_app():
115
- """Launch the Gradio app"""
116
- try:
117
- # Create app instance
118
- app = BiomassPredictorApp()
119
-
120
- # Create interface
121
- interface = app.create_interface()
122
-
123
- # Launch interface
124
- interface.launch()
125
- except Exception as e:
126
- logger.error(f"Error launching app: {e}")
127
- logger.error(traceback.format_exc())
128
 
129
- if __name__ == "__main__":
130
- launch_app()"""
131
- Biomass Prediction Gradio App with Two Sample Images and RGB Comparison
132
  Author: najahpokkiri
133
  Date: 2025-05-19
134
-
135
- Updated with sample image thumbnails and always-on RGB comparison.
136
  """
137
- import os
138
- import sys
139
- import torch
140
  import numpy as np
141
- import gradio as gr
142
- import joblib
143
- import tempfile
144
- import matplotlib.pyplot as plt
145
- import matplotlib.colors as colors
146
- from PIL import Image
147
- import io
148
  import logging
149
- from huggingface_hub import hf_hub_download
150
- import rasterio
151
 
152
  # Configure logger
153
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
154
  logger = logging.getLogger(__name__)
155
 
156
- # Import model architecture
157
- from model import StableResNet
 
 
 
 
 
 
158
 
159
- # Define a placeholder for feature engineering if not available
160
- def extract_all_features(image):
161
- """
162
- Extract all 99 features from satellite bands.
163
- Placeholder function - in production, use the actual feature_engineering module.
164
- """
165
- # Get image dimensions
166
- n_bands, height, width = image.shape
167
-
168
- # Create a valid mask (non-NaN pixels)
169
- valid_mask = np.all(np.isfinite(image), axis=0)
170
-
171
- # Get valid pixel coordinates
172
- valid_y, valid_x = np.where(valid_mask)
173
- n_valid = len(valid_y)
174
 
175
- # Create a feature matrix (placeholder)
176
- # In a real scenario, these would be spectral indices, texture features, etc.
177
- # For now, we'll just use the original bands and pad to 99 features
178
 
179
- # Original bands for each valid pixel
180
- feature_matrix = np.zeros((n_valid, 99), dtype=np.float32)
 
 
181
 
182
- # Fill in the available band values
183
- for i in range(n_valid):
184
- y, x = valid_y[i], valid_x[i]
185
- # Copy available bands
186
- for b in range(min(n_bands, 99)):
187
- feature_matrix[i, b] = image[b, y, x]
188
 
189
- # Create feature names
190
- generated_features = [f"Band_{i+1}" for i in range(99)]
 
191
 
192
- return feature_matrix, valid_mask, generated_features
193
-
194
- class BiomassPredictorApp:
195
- """Gradio app for biomass prediction from satellite imagery"""
196
-
197
- def __init__(self, model_repo="pokkiri/biomass-model"):
198
- """Initialize the app with model repository information"""
199
- self.model = None
200
- self.package = None
201
- self.feature_names = []
202
- self.model_repo = model_repo
203
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
204
-
205
- # Sample image paths
206
- self.sample_images = {
207
- "input_chip_1": "input_chip_1.tif",
208
- "input_chip_2": "input_chip_2.tif"
209
- }
210
-
211
- # Cache for storing temporary files
212
- self.temp_files = []
213
 
214
- # Load the model
215
- self.load_model()
216
-
217
- def load_model(self):
218
- """Load the model and preprocessing pipeline"""
219
- try:
220
- logger.info(f"Loading model from {self.model_repo}")
221
-
222
- # Download model files from HuggingFace or use local files
223
- try:
224
- model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
225
- package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
226
- except Exception as e:
227
- logger.warning(f"Failed to download from HuggingFace: {e}")
228
- # Fallback to local files
229
- model_path = "model.pt"
230
- package_path = "model_package.pkl"
231
 
232
- # Try to load package with metadata
233
- try:
234
- logger.info(f"Loading package from {package_path}")
235
- self.package = joblib.load(package_path)
236
- logger.info("Successfully loaded model package")
237
 
238
- # Extract information from package
239
- n_features = self.package['n_features']
240
- self.feature_names = self.package.get('feature_names', [f"feature_{i}" for i in range(n_features)])
241
 
242
- logger.info(f"Package keys: {list(self.package.keys())}")
243
- logger.info(f"Model expects {n_features} features")
244
- except Exception as e:
245
- logger.error(f"Error loading package file: {e}")
246
- # Fallback to default values
247
- n_features = 99 # We know there are 99 features
248
- self.feature_names = [f"feature_{i}" for i in range(n_features)]
249
 
250
- # Create a minimal package with essential components
251
- self.package = {
252
- 'n_features': n_features,
253
- 'use_log_transform': True,
254
- 'epsilon': 1.0,
255
- 'scaler': None # Will handle the None case in prediction
256
- }
257
-
258
- # Initialize model
259
- self.model = StableResNet(n_features=n_features)
260
- self.model.load_state_dict(torch.load(model_path, map_location=self.device))
261
- self.model.to(self.device)
262
- self.model.eval()
263
-
264
- logger.info(f"Model loaded successfully")
265
- logger.info(f"Number of features: {n_features}")
266
- logger.info(f"Using device: {self.device}")
267
-
268
- return True
269
- except Exception as e:
270
- logger.error(f"Error loading model: {e}")
271
- import traceback
272
- logger.error(traceback.format_exc())
273
- return False
274
-
275
- def cleanup(self):
276
- """Clean up temporary files"""
277
- for tmp_path in self.temp_files:
278
- try:
279
- if os.path.exists(tmp_path):
280
- os.unlink(tmp_path)
281
- except Exception as e:
282
- logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
283
 
284
- self.temp_files = []
285
-
286
- def create_thumbnail(self, image_path, max_size=(200, 200), output_format="PNG"):
287
- """Create a thumbnail image from a GeoTIFF"""
288
- try:
289
- if not os.path.exists(image_path):
290
- logger.warning(f"Image file not found: {image_path}")
291
- return None
292
-
293
- # Open the GeoTIFF
294
- with rasterio.open(image_path) as src:
295
- # Read data with RGB bands if available
296
- if src.count >= 3:
297
- # Use first three bands as RGB
298
- rgb_data = src.read([1, 2, 3])
299
-
300
- # Transpose from (bands, height, width) to (height, width, bands)
301
- rgb_data = np.transpose(rgb_data, (1, 2, 0))
302
-
303
- # Normalize to 0-255 range
304
- rgb_data = np.clip(rgb_data, 0, None) # Clip negative values
305
- for i in range(3):
306
- p2 = np.percentile(rgb_data[:,:,i], 2)
307
- p98 = np.percentile(rgb_data[:,:,i], 98)
308
- if p98 > p2:
309
- rgb_data[:,:,i] = np.clip((rgb_data[:,:,i] - p2) / (p98 - p2) * 255, 0, 255)
310
- else:
311
- rgb_data[:,:,i] = np.clip(rgb_data[:,:,i] / (rgb_data[:,:,i].max() or 1) * 255, 0, 255)
312
-
313
- # Convert to uint8
314
- rgb_data = rgb_data.astype(np.uint8)
315
-
316
- # Create PIL image
317
- img = Image.fromarray(rgb_data)
318
- else:
319
- # Use first band as grayscale
320
- gray_data = src.read(1)
321
-
322
- # Normalize to 0-255 range
323
- p2 = np.percentile(gray_data, 2)
324
- p98 = np.percentile(gray_data, 98)
325
- if p98 > p2:
326
- gray_data = np.clip((gray_data - p2) / (p98 - p2) * 255, 0, 255)
327
- else:
328
- gray_data = np.clip(gray_data / (gray_data.max() or 1) * 255, 0, 255)
329
-
330
- # Convert to uint8
331
- gray_data = gray_data.astype(np.uint8)
332
-
333
- # Create PIL image
334
- img = Image.fromarray(gray_data, mode='L')
335
-
336
- # Resize to thumbnail
337
- img.thumbnail(max_size)
338
-
339
- # Save to bytes buffer
340
- buf = io.BytesIO()
341
- img.save(buf, format=output_format)
342
- buf.seek(0)
343
 
344
- return buf
345
- except Exception as e:
346
- logger.error(f"Error creating thumbnail: {e}")
347
- return None
348
-
349
- def predict_biomass(self, image_file):
350
- """Predict biomass from a satellite image with RGB comparison"""
351
- if self.model is None:
352
- return None, "Error: Model not loaded. Please check logs for details."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
- if image_file is None:
355
- return None, "Error: No file uploaded. Please upload a GeoTIFF file or use one of the sample images."
 
356
 
357
- try:
358
- # Check if we're using a sample image (string path) or an uploaded file
359
- if isinstance(image_file, str):
360
- logger.info(f"Using sample image: {image_file}")
361
- tmp_path = image_file # Use the sample path directly
362
- cleanup_tmp = False # Don't delete the sample file
363
- else:
364
- # Create a temporary file to save the uploaded file
365
- with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
366
- tmp_path = tmp_file.name
367
- with open(image_file.name, 'rb') as f:
368
- tmp_file.write(f.read())
369
-
370
- # Add to list for cleanup later
371
- self.temp_files.append(tmp_path)
372
- cleanup_tmp = True
373
-
374
- # Open the image file
375
- with rasterio.open(tmp_path) as src:
376
- image = src.read()
377
- height, width = image.shape[1], image.shape[2]
378
- transform = src.transform
379
- crs = src.crs
380
-
381
- logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
382
-
383
- # Validate minimum band count
384
- if image.shape[0] < 3:
385
- return None, f"Error: Image has only {image.shape[0]} bands. At least 3 bands are required for RGB visualization."
386
-
387
- # Generate all features using feature engineering
388
- logger.info("Generating all 99 features from bands...")
389
- feature_matrix, valid_mask, generated_features = extract_all_features(image)
390
-
391
- # Verify we have exactly 99 features
392
- if feature_matrix.shape[1] != 99:
393
- logger.error(f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99.")
394
- return None, f"Error: Generated {feature_matrix.shape[1]} features, but model expects 99."
395
-
396
- # Apply feature scaling if available
397
  try:
398
- if 'scaler' in self.package and self.package['scaler'] is not None:
399
- logger.info("Applying feature scaling...")
400
- feature_matrix = self.package['scaler'].transform(feature_matrix)
401
- except Exception as e:
402
- logger.warning(f"Error applying scaler: {e}. Using original features.")
403
-
404
- # Initialize predictions array
405
- predictions = np.zeros((height, width), dtype=np.float32)
406
-
407
- # Get valid pixel coordinates
408
- valid_y, valid_x = np.where(valid_mask)
409
-
410
- # Make predictions
411
- logger.info(f"Running model inference on {len(valid_y)} valid pixels...")
412
- with torch.no_grad():
413
- # Process in batches to avoid memory issues
414
- batch_size = 10000
415
- for i in range(0, len(valid_y), batch_size):
416
- end_idx = min(i + batch_size, len(valid_y))
417
- batch = feature_matrix[i:end_idx]
418
-
419
- # Convert to tensor
420
- batch_tensor = torch.tensor(batch, dtype=torch.float32).to(self.device)
421
-
422
- # Get predictions
423
- batch_predictions = self.model(batch_tensor).cpu().numpy()
424
-
425
- # Handle scalar case for single-item batches
426
- if batch_predictions.ndim == 0:
427
- batch_predictions = np.array([batch_predictions])
428
-
429
- # Convert from log scale if needed
430
- if self.package.get('use_log_transform', True):
431
- epsilon = self.package.get('epsilon', 1.0)
432
- batch_predictions = np.exp(batch_predictions) - epsilon
433
- batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
434
-
435
- # Map predictions back to image
436
- for j, pred in enumerate(batch_predictions):
437
- y_idx = valid_y[i + j]
438
- x_idx = valid_x[i + j]
439
- predictions[y_idx, x_idx] = pred
440
-
441
- # Log progress
442
- if (i // batch_size) % 5 == 0 or end_idx == len(valid_y):
443
- logger.info(f"Processed {end_idx}/{len(valid_y)} pixels")
444
-
445
- # Create visualization - always RGB+Biomass side-by-side
446
- logger.info("Creating RGB + Biomass visualization...")
447
-
448
- # Create side-by-side comparison (RGB and Biomass)
449
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
450
-
451
- # Prepare RGB image - try different band combinations if needed
452
- rgb_bands = [3, 2, 1] # Common RGB combination (R,G,B)
453
-
454
- # Check if we have enough bands for RGB
455
- if image.shape[0] < 3:
456
- logger.warning(f"Image has only {image.shape[0]} bands, using available bands for display")
457
- rgb_bands = list(range(min(3, image.shape[0])))
458
- while len(rgb_bands) < 3:
459
- rgb_bands.append(0) # Pad with zeros if needed
460
-
461
- # Create RGB image
462
- rgb = np.zeros((height, width, 3), dtype=np.float32)
463
- for i, band_idx in enumerate(rgb_bands):
464
- if band_idx < image.shape[0]:
465
- rgb[:, :, i] = image[band_idx]
466
-
467
- # Handle potential NaN values
468
- rgb = np.nan_to_num(rgb)
469
-
470
- # Enhance contrast with percentile-based normalization
471
- for i in range(3):
472
- p2 = np.percentile(rgb[:,:,i], 2)
473
- p98 = np.percentile(rgb[:,:,i], 98)
474
- if p98 > p2:
475
- rgb[:,:,i] = np.clip((rgb[:,:,i] - p2) / (p98 - p2), 0, 1)
476
-
477
- # Display RGB image
478
- ax1.imshow(rgb)
479
- ax1.set_title('RGB Image')
480
- ax1.axis('off')
481
-
482
- # Display biomass prediction
483
- masked_predictions = np.ma.masked_where(~valid_mask, predictions)
484
- vmin = np.percentile(predictions[valid_mask], 1)
485
- vmax = np.percentile(predictions[valid_mask], 99)
486
-
487
- im = ax2.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
488
- fig.colorbar(im, ax=ax2, label='Biomass (Mg/ha)')
489
- ax2.set_title('Predicted Biomass')
490
- ax2.axis('off')
491
-
492
- # Add super title
493
- plt.suptitle('RGB Image and Biomass Prediction', fontsize=16)
494
- plt.tight_layout()
495
-
496
- # Save figure to bytes buffer
497
- buf = io.BytesIO()
498
- fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
499
- buf.seek(0)
500
- plt.close(fig)
501
-
502
- # Calculate summary statistics
503
- valid_predictions = predictions[valid_mask]
504
- stats = {
505
- 'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
506
- 'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
507
- 'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
508
- 'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
509
- }
510
-
511
- # Add area and total biomass if transform is available
512
- if transform is not None:
513
- pixel_area_m2 = abs(transform[0] * transform[4]) # Assuming square pixels
514
- total_biomass = np.sum(valid_predictions) * (pixel_area_m2 / 10000) # Convert to hectares
515
- area_hectares = np.sum(valid_mask) * (pixel_area_m2 / 10000)
516
-
517
- stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
518
- stats['Area'] = f"{area_hectares:.2f} hectares"
519
-
520
- # Format statistics as markdown
521
- stats_md = "### Biomass Statistics\n\n"
522
- stats_md += "| Metric | Value |\n|--------|-------|\n"
523
- for k, v in stats.items():
524
- stats_md += f"| {k} | {v} |\n"
525
-
526
- # Add processing info
527
- stats_md += f"\n\n*Processed {np.sum(valid_mask):,} valid pixels with {feature_matrix.shape[1]} features*"
528
-
529
- # Cleanup temporary files if needed
530
- if cleanup_tmp:
531
- self.cleanup()
532
-
533
- # Return visualization and statistics
534
- return Image.open(buf), stats_md
535
-
536
- except Exception as e:
537
- # Ensure cleanup even on error
538
- self.cleanup()
539
-
540
- import traceback
541
- logger.error(f"Error predicting biomass: {e}")
542
- logger.error(traceback.format_exc())
543
-
544
- return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
545
 
546
- def create_interface(self):
547
- """Create Gradio interface with sample image thumbnails"""
548
- # Generate thumbnails for sample images
549
- sample_thumbnails = {}
550
- for name, path in self.sample_images.items():
551
- if os.path.exists(path):
552
- thumbnail = self.create_thumbnail(path)
553
- if thumbnail:
554
- sample_thumbnails[name] = Image.open(thumbnail)
 
 
 
 
555
  else:
556
- logger.warning(f"Sample image not found: {path}")
 
 
 
 
 
 
557
 
558
- with gr.Blocks(title="Biomass Prediction Model") as interface:
559
- gr.Markdown("# Above-Ground Biomass Prediction")
560
- gr.Markdown("""
561
- Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
562
-
563
- **Requirements:**
564
- - Image must be a GeoTIFF with spectral bands
565
- - For best results, image should contain at least 3 bands
566
- """)
567
-
568
- with gr.Row():
569
- with gr.Column(scale=1):
570
- input_image = gr.File(
571
- label="Upload Satellite Image (GeoTIFF)",
572
- file_types=[".tif", ".tiff"]
573
- )
574
-
575
- # Sample images section
576
- gr.Markdown("### Sample Images")
577
-
578
- # Sample buttons container
579
- sample_buttons = []
580
-
581
- # First row - sample thumbnails side by side horizontally
582
- with gr.Row():
583
- for name, thumbnail in sample_thumbnails.items():
584
- with gr.Column():
585
- gr.Image(
586
- value=thumbnail,
587
- label=name.replace("input_", "Input ").replace("chip_", "Chip "),
588
- show_download_button=False,
589
- height=180
590
- )
591
-
592
- # Second row - buttons side by side horizontally, matching the thumbnails above
593
- with gr.Row():
594
- for name, _ in sample_thumbnails.items():
595
- with gr.Column():
596
- sample_btn = gr.Button(
597
- f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
598
- variant="secondary",
599
- size="lg"
600
- )
601
- sample_buttons.append((sample_btn, name))
602
-
603
- # Generate button at the bottom
604
- generate_btn = gr.Button("Generate Biomass Prediction", variant="primary", size="lg")
605
-
606
- with gr.Column(scale=2):
607
- output_image = gr.Image(
608
- label="Biomass Prediction Map",
609
- type="pil"
610
- )
611
-
612
- output_stats = gr.Markdown(
613
- label="Statistics"
614
- )_image = gr.Image(
615
- label="Biomass Prediction Map",
616
- type="pil"
617
- )
618
-
619
- output_stats = gr.Markdown(
620
- label="Statistics"
621
- )
622
-
623
- # Sample images section with thumbnails in a separate row
624
- gr.Markdown("### Sample Images")
625
-
626
- with gr.Row():
627
- # Only show thumbnails for images that were found
628
- sample_buttons = []
629
-
630
- # Create a column for each sample image
631
- for name, thumbnail in sample_thumbnails.items():
632
- with gr.Column():
633
- gr.Image(value=thumbnail, label=name.replace("input_", "Input ").replace("chip_", "Chip "),
634
- show_download_button=False, show_label=True, height=200)
635
- sample_btn = gr.Button(f"Use {name.replace('input_', 'Input ').replace('chip_', 'Chip ')}",
636
- size="lg", variant="secondary")
637
- sample_buttons.append((sample_btn, name))
638
-
639
- with gr.Column(scale=2):
640
- output_image = gr.Image(
641
- label="Biomass Prediction Map",
642
- type="pil"
643
- )
644
-
645
- output_stats = gr.Markdown(
646
- label="Statistics"
647
- )
648
-
649
- with gr.Accordion("About", open=False):
650
- gr.Markdown("""
651
- ## About This Model
652
-
653
- This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
654
-
655
- ### Model Details
656
-
657
- - Architecture: StableResNet
658
- - Input: Multi-spectral satellite imagery
659
- - Output: Above-ground biomass (Mg/ha)
660
- - Creator: vertify.earth for GIZ Forest Forward
661
- - Date: 2025-05-19
662
-
663
- ### How It Works
664
-
665
- 1. The model extracts features from each pixel in the satellite image
666
- 2. These features include spectral bands, vegetation indices, texture metrics, and more
667
- 3. The model outputs a biomass prediction for each pixel
668
- 4. Results are visualized as RGB and biomass prediction side-by-side
669
- """)
670
-
671
- # Add a warning if model failed to load
672
- if self.model is None:
673
- gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
674
-
675
- # Connect the process button
676
- process_btn.click(
677
- fn=self.predict_biomass,
678
- inputs=[input_image],
679
- outputs=[output_image, output_stats]
680
- )
681
-
682
- # Connect the sample buttons
683
- for button, name in sample_buttons:
684
- button.click(
685
- fn=lambda path=self.sample_images[name]: self.predict_biomass(path),
686
- inputs=[],
687
- outputs=[output_image, output_stats]
688
- )
689
 
690
- return interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
- def launch_app():
693
- """Launch the Gradio app"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
  try:
695
- # Create app instance
696
- app = BiomassPredictorApp()
 
 
 
697
 
698
- # Create interface
699
- interface = app.create_interface()
 
 
 
700
 
701
- # Launch interface
702
- interface.launch()
703
  except Exception as e:
704
- logger.error(f"Error launching app: {e}")
705
  import traceback
706
- logger.error(traceback.format_exc())
 
707
 
708
  if __name__ == "__main__":
709
- launch_app()
 
 
1
+ """
2
+ Feature engineering module for biomass prediction.
3
+ This module extracts the 99 features needed by the StableResNet model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
5
  Author: najahpokkiri
6
  Date: 2025-05-19
 
 
7
  """
 
 
 
8
  import numpy as np
 
 
 
 
 
 
 
9
  import logging
10
+ from datetime import datetime
 
11
 
12
  # Configure logger
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
  logger = logging.getLogger(__name__)
15
 
16
+ # Try to import optional dependencies but don't fail if not available
17
+ try:
18
+ from sklearn.preprocessing import StandardScaler
19
+ from sklearn.decomposition import PCA
20
+ SKLEARN_AVAILABLE = True
21
+ except ImportError:
22
+ SKLEARN_AVAILABLE = False
23
+ logger.warning("scikit-learn not available. PCA features will be approximated.")
24
 
25
+ try:
26
+ from skimage.filters import sobel
27
+ from skimage.feature import local_binary_pattern, graycomatrix, graycoprops
28
+ SKIMAGE_AVAILABLE = True
29
+ except ImportError:
30
+ SKIMAGE_AVAILABLE = False
31
+ logger.warning("scikit-image not available. Texture features will be approximated.")
32
+
33
def safe_divide(a, b, fill_value=0.0):
    """Divide ``a`` by ``b`` element-wise, using ``fill_value`` where unsafe.

    Both operands are converted to float32 and broadcast against each other,
    so scalars and arrays of compatible shapes can be mixed freely.

    Args:
        a: Numerator (scalar or array-like). NaN and +/-Inf entries are
            treated as 0.
        b: Denominator (scalar or array-like). +/-Inf entries are clamped to
            +/-1e10 and NaN entries are replaced by 1e-10.
        fill_value: Value written wherever ``|b| < 1e-10`` and wherever the
            quotient itself is NaN or infinite.

    Returns:
        float32 quotient with the broadcast shape of ``a`` and ``b``.
    """
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)

    # Sanitise non-finite inputs before dividing.
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    # NOTE(review): a NaN denominator becomes 1e-10, which is NOT below the
    # threshold used next, so the quotient there is a * 1e10 rather than
    # fill_value. Kept as-is to preserve existing feature values; confirm
    # whether that is intended.
    b = np.nan_to_num(b, nan=1e-10, posinf=1e10, neginf=-1e10)

    # Broadcast so that a scalar (or lower-rank) operand on either side works.
    a, b = np.broadcast_arrays(a, b)

    near_zero = np.abs(b) < 1e-10
    result = np.full(a.shape, fill_value, dtype=np.float32)
    # Only divide where the denominator is usable; other cells keep fill_value.
    np.divide(a, b, out=result, where=~near_zero)

    return np.nan_to_num(result, nan=fill_value, posinf=fill_value, neginf=fill_value)
48
+
49
def calculate_spectral_indices(satellite_data):
    """Calculate spectral indices from satellite bands.

    Args:
        satellite_data: Array indexed as ``[band, row, col]``. Bands are read
            at fixed positions: 1 (blue), 2 (green), 3 (red), 7 (NIR),
            9 (SWIR1) and 10 (SWIR2).
            NOTE(review): this mapping assumes a Sentinel-2-like band order;
            adjust the indices if the input stack is ordered differently.

    Returns:
        Dict mapping index name to a 2-D array. The keys 'NDVI', 'EVI',
        'SAVI', 'MSAVI2', 'NDWI', 'NDMI' and 'NBR' are always present for a
        non-empty image; any index that could not be computed is filled with
        float32 zeros. NaN values in computed indices are replaced by 0.
    """
    indices = {}
    n_bands = satellite_data.shape[0]

    def safe_get_band(idx):
        """Return band ``idx`` ready for signed arithmetic, or None if absent."""
        if idx >= n_bands:
            return None
        band = satellite_data[idx]
        # Integer rasters (e.g. uint16) wrap around on expressions such as
        # ``nir - red``; promote them to float before any arithmetic.
        if band.dtype.kind in 'biu':
            band = band.astype(np.float32)
        return band

    # Sentinel-2 bands (assuming standard band order)
    # B2(blue), B3(green), B4(red), B8(nir), B11(swir1), B12(swir2)
    try:
        blue = safe_get_band(1)  # Adjust indices based on your data
        green = safe_get_band(2)
        red = safe_get_band(3)
        nir = safe_get_band(7)
        swir1 = safe_get_band(9)
        swir2 = safe_get_band(10)

        # Insertion order below is kept stable in case callers iterate the
        # dict to build a feature stack.
        if red is not None and nir is not None:
            # NDVI (Normalized Difference Vegetation Index)
            indices['NDVI'] = safe_divide(nir - red, nir + red)

            if blue is not None and green is not None:
                # EVI (Enhanced Vegetation Index)
                indices['EVI'] = 2.5 * safe_divide(nir - red, nir + 6*red - 7.5*blue + 1)

            # SAVI (Soil Adjusted Vegetation Index)
            indices['SAVI'] = 1.5 * safe_divide(nir - red, nir + red + 0.5)

            # MSAVI2 (Modified Soil Adjusted Vegetation Index)
            # A negative radicand yields NaN, which is zeroed further down;
            # only the runtime warning is silenced here.
            with np.errstate(invalid='ignore'):
                indices['MSAVI2'] = 0.5 * (2 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red)))

            if green is not None:
                # NDWI (Normalized Difference Water Index)
                indices['NDWI'] = safe_divide(green - nir, green + nir)

        if swir1 is not None and nir is not None:
            # NDMI (Normalized Difference Moisture Index)
            indices['NDMI'] = safe_divide(nir - swir1, nir + swir1)

        if swir2 is not None and nir is not None:
            # NBR (Normalized Burn Ratio)
            indices['NBR'] = safe_divide(nir - swir2, nir + swir2)

    except Exception as e:
        # Best effort: whatever was computed so far is kept, the rest is
        # filled with zeros below.
        logger.warning(f"Error calculating spectral indices: {e}")

    # Clean up None values and NaNs
    indices = {k: np.nan_to_num(v, nan=0.0) for k, v in indices.items() if v is not None}

    # Ensure we have all required indices by providing defaults
    required_indices = ['NDVI', 'EVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDMI', 'NBR']
    height, width = satellite_data.shape[1], satellite_data.shape[2]
    for idx in required_indices:
        if idx not in indices and height > 0 and width > 0:
            indices[idx] = np.zeros((height, width), dtype=np.float32)

    return indices
107
+
108
def extract_texture_features(satellite_data):
    """Extract texture features from satellite data.

    Texture is computed on band index 7 (or the last band when fewer than
    eight bands are present): a Sobel edge map, a uniform Local Binary
    Pattern map and four GLCM properties. The GLCM properties are computed
    once on a central patch of at most 128x128 pixels and broadcast to the
    full image.

    Args:
        satellite_data: Array indexed as ``[band, row, col]``.

    Returns:
        Dict with the keys 'Sobel_B7', 'LBP_B7', 'GLCM_contrast_B7',
        'GLCM_dissimilarity_B7', 'GLCM_homogeneity_B7' and 'GLCM_energy_B7',
        each a ``(rows, cols)`` array. All six are float32 zeros when
        scikit-image is unavailable or when extraction fails.
    """
    texture_features = {}
    height, width = satellite_data.shape[1], satellite_data.shape[2]
    glcm_props = ['contrast', 'dissimilarity', 'homogeneity', 'energy']
    texture_names = ['Sobel_B7', 'LBP_B7', 'GLCM_contrast_B7', 'GLCM_dissimilarity_B7',
                     'GLCM_homogeneity_B7', 'GLCM_energy_B7']

    def _zero_map():
        """Return a fresh all-zero placeholder feature map."""
        return np.zeros((height, width), dtype=np.float32)

    # If scikit-image is not available, return placeholders
    if not SKIMAGE_AVAILABLE:
        for name in texture_names:
            texture_features[name] = _zero_map()
        return texture_features

    try:
        # Use NIR band (band 7) for texture features
        b7_idx = min(7, satellite_data.shape[0] - 1)
        band = np.nan_to_num(satellite_data[b7_idx].copy(), nan=0.0)

        # 1. Sobel filter for edge detection
        texture_features['Sobel_B7'] = sobel(band)

        # 2. Local Binary Pattern
        # Normalize band to 0-255 range for LBP
        band_norm = band.copy()
        if np.any(~np.isnan(band)):
            band_min, band_max = np.nanpercentile(band, [1, 99])
            if band_max > band_min:
                scaled = (band - band_min) / (band_max - band_min + 1e-8) * 255
                band_norm = np.clip(scaled, 0, 255).astype(np.uint8)
            else:
                # Flat band: nothing to stretch.
                band_norm = np.zeros_like(band, dtype=np.uint8)

        texture_features['LBP_B7'] = local_binary_pattern(band_norm, 8, 1, method='uniform')

        # 3. GLCM properties
        # Create sample patch for GLCM calculation
        sample_size = min(128, height, width)
        center_y, center_x = height // 2, width // 2
        offset = sample_size // 2
        y_start = max(0, center_y - offset)
        y_end = min(height, center_y + offset)
        x_start = max(0, center_x - offset)
        x_end = min(width, center_x + offset)
        patch = band_norm[y_start:y_end, x_start:x_end]

        if patch.size > 0:
            glcm = graycomatrix(patch, [1], [0], levels=256, symmetric=True, normed=True)
            for prop in glcm_props:
                key = f'GLCM_{prop}_B7'
                try:
                    value = float(graycoprops(glcm, prop)[0, 0])
                    texture_features[key] = np.full((height, width), value)
                except Exception as prop_err:
                    # A single failing property must not discard the others.
                    logger.warning(f"Error calculating GLCM {prop}: {prop_err}")
                    texture_features[key] = _zero_map()
        else:
            # Create placeholder GLCM features if patch is invalid
            for prop in glcm_props:
                texture_features[f'GLCM_{prop}_B7'] = _zero_map()

    except Exception as e:
        logger.error(f"Error in texture feature extraction: {e}")
        # Provide placeholder features in case of error; partially computed
        # features are overwritten so the output is all-or-nothing.
        for name in texture_names:
            texture_features[name] = _zero_map()

    return texture_features
179
+
180
def calculate_spatial_features(satellite_data, indices):
    """Calculate spatial context features (gradient magnitudes).

    Two rasters are produced:

    * ``Gradient_B7``   - gradient magnitude of the band stored at index
      ``min(7, n_bands - 1)`` of ``satellite_data``.
    * ``NDVI_gradient`` - gradient magnitude of ``indices['NDVI']``; an
      all-zero raster is used when that key is missing.

    NaN and +/-Inf input values are replaced by 0 before differentiation so
    that a single bad pixel cannot poison its neighbours.  If a gradient
    cannot be computed (for example the raster is smaller than 2 pixels
    along an axis, or there are no bands), a warning is logged and a zero
    raster of shape (height, width) is returned for that feature.

    Parameters:
        satellite_data (ndarray): Array of shape (bands, height, width)
        indices (dict): Mapping of spectral index name to 2-D array

    Returns:
        spatial_features (dict): ``{'Gradient_B7': ndarray,
        'NDVI_gradient': ndarray}``
    """
    height, width = satellite_data.shape[1], satellite_data.shape[2]

    def _zero_raster():
        # Placeholder used whenever a feature cannot be computed.
        return np.zeros((height, width), dtype=np.float32)

    def _gradient_magnitude(raster):
        # np.nan_to_num maps +/-Inf to +/-float-max by default, which then
        # overflows when squared; treat every non-finite value as 0 instead.
        cleaned = np.nan_to_num(raster, nan=0.0, posinf=0.0, neginf=0.0)
        grad_y, grad_x = np.gradient(cleaned)
        # hypot == sqrt(x**2 + y**2) without the intermediate overflow.
        return np.hypot(grad_x, grad_y)

    spatial_features = {}

    # 1. Gradient of the "B7" band.
    # NOTE(review): index 7 is the 8th band when counting from zero; the
    # feature is nevertheless named B7 - confirm which band is intended.
    try:
        b7_idx = min(7, satellite_data.shape[0] - 1)
        spatial_features['Gradient_B7'] = _gradient_magnitude(satellite_data[b7_idx])
    except Exception as e:
        logger.warning(f"Error calculating band gradient: {e}")
        spatial_features['Gradient_B7'] = _zero_raster()

    # 2. NDVI gradient
    try:
        ndvi = indices.get('NDVI', _zero_raster())
        spatial_features['NDVI_gradient'] = _gradient_magnitude(ndvi)
    except Exception as e:
        logger.warning(f"Error calculating NDVI gradient: {e}")
        spatial_features['NDVI_gradient'] = _zero_raster()

    return spatial_features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
def calculate_pca_features(satellite_data, n_components=25):
    """Calculate PCA features from satellite bands.

    Every pixel is treated as one sample whose features are its band values.
    Pixels with only finite band values are standardized and projected onto
    principal components fitted on those same pixels.  The result always
    contains exactly ``n_components`` rasters named ``PCA_01`` ...
    ``PCA_NN``:

    * components beyond what PCA can provide (limited by the number of bands
      and of valid pixels) are zero-filled;
    * pixels containing NaN or +/-Inf in any band get 0 in every component;
    * when scikit-learn is unavailable the first ``n_bands`` "components"
      are simply the raw bands and the rest are zeros;
    * on any error, or when no pixel is valid, all components are zeros.

    Parameters:
        satellite_data (ndarray): Array of shape (bands, height, width)
        n_components (int): Number of PCA rasters to return

    Returns:
        pca_features (dict): Mapping ``'PCA_XX'`` -> array of shape
        (height, width)
    """
    height, width = satellite_data.shape[1], satellite_data.shape[2]
    n_bands = satellite_data.shape[0]
    component_names = [f'PCA_{i:02d}' for i in range(1, n_components + 1)]

    def _placeholders():
        # All-zero rasters, used whenever real components are unavailable.
        return {name: np.zeros((height, width), dtype=np.float32)
                for name in component_names}

    # If scikit-learn is not available, fall back to raw bands / zeros
    if not SKLEARN_AVAILABLE:
        pca_features = _placeholders()
        for band_idx, name in enumerate(component_names[:n_bands]):
            pca_features[name] = satellite_data[band_idx]
        return pca_features

    try:
        # Reshape for PCA (pixels x bands)
        bands_reshaped = satellite_data.reshape(n_bands, -1).T

        # Keep only pixels whose bands are all finite.  Filtering NaN alone
        # lets +/-Inf reach StandardScaler, which raises and would zero out
        # every PCA feature of the image.
        valid_mask = np.all(np.isfinite(bands_reshaped), axis=1)
        bands_clean = bands_reshaped[valid_mask]

        if len(bands_clean) == 0:
            logger.warning("No valid data for PCA calculation")
            return _placeholders()

        # Standardize valid data
        bands_scaled = StandardScaler().fit_transform(bands_clean)

        # PCA cannot yield more components than min(n_samples, n_features)
        pca = PCA(n_components=min(n_components, bands_scaled.shape[1], bands_scaled.shape[0]))
        pca_result = pca.fit_transform(bands_scaled)

        actual_components = pca_result.shape[1]
        if actual_components < n_components:
            logger.warning(f"Only {actual_components} PCA components calculated, padding to {n_components}")

        # Map back to original pixels; untouched cells (invalid pixels and
        # missing trailing components) keep their zero padding.
        pca_all = np.zeros((bands_reshaped.shape[0], n_components))
        pca_all[valid_mask, :actual_components] = pca_result

        # Reshape to spatial dimensions and store each component by name
        pca_spatial = pca_all.reshape(height, width, n_components)
        pca_features = {name: pca_spatial[:, :, idx]
                        for idx, name in enumerate(component_names)}

        # Log PCA explained variance
        if hasattr(pca, 'explained_variance_ratio_'):
            logger.info(f"PCA explained variance: {pca.explained_variance_ratio_.sum():.3f}")

    except Exception as e:
        logger.error(f"Error calculating PCA features: {e}")
        pca_features = _placeholders()

    return pca_features
284
 
285
def extract_all_features(satellite_data):
    """
    Extract exactly 99 features needed by the model:
    - 59 original bands (zero-padded when fewer bands are supplied,
      truncated when more are supplied)
    - 7 spectral indices
    - 6 texture features
    - 2 spatial features
    - 25 PCA components

    Only pixels whose every band is finite are returned.  Any NaN or +/-Inf
    that appears in a derived feature at such a pixel is replaced by 0 so
    the returned matrix is always finite.  A feature that none of the
    calculators produced is logged and left as a zero column.

    Parameters:
        satellite_data (ndarray): Array of shape (bands, height, width)

    Returns:
        features_array (ndarray): float32 array of shape (valid_pixels, 99)
        valid_mask (ndarray): Boolean mask of valid pixels, shape (height, width)
        feature_names (list): List of 99 feature names, in column order
    """
    n_model_bands = 59
    n_pca_components = 25

    start_time = datetime.now()
    logger.info("Extracting features for biomass prediction...")
    height, width = satellite_data.shape[1], satellite_data.shape[2]

    # Create valid pixel mask (no NaN or Inf values)
    valid_mask = np.all(np.isfinite(satellite_data), axis=0)
    valid_y, valid_x = np.where(valid_mask)
    n_valid = len(valid_y)

    logger.info(f"Found {n_valid} valid pixels out of {height*width}")

    # Generate all feature categories
    logger.info("Calculating spectral indices...")
    indices = calculate_spectral_indices(satellite_data)

    logger.info("Extracting texture features...")
    texture_features = extract_texture_features(satellite_data)

    logger.info("Calculating spatial features...")
    spatial_features = calculate_spatial_features(satellite_data, indices)

    logger.info("Computing PCA components...")
    pca_features = calculate_pca_features(satellite_data, n_components=n_pca_components)

    # Ordered list of feature names; this order defines the matrix columns.
    band_names = [f'Band_{i:02d}' for i in range(1, n_model_bands + 1)]
    spectral_indices = ['NDVI', 'EVI', 'SAVI', 'MSAVI2', 'NDWI', 'NDMI', 'NBR']
    texture_names = ['Sobel_B7', 'LBP_B7', 'GLCM_contrast_B7', 'GLCM_dissimilarity_B7',
                     'GLCM_homogeneity_B7', 'GLCM_energy_B7']
    spatial_names = ['Gradient_B7', 'NDVI_gradient']
    pca_names = [f'PCA_{i:02d}' for i in range(1, n_pca_components + 1)]
    feature_names = band_names + spectral_indices + texture_names + spatial_names + pca_names

    # Verify we have exactly 99 features
    assert len(feature_names) == 99, f"Expected 99 features, but got {len(feature_names)}"

    # Collect every available feature raster by name.
    all_features = {}
    n_available = satellite_data.shape[0]
    for band_idx, name in enumerate(band_names):
        if band_idx < n_available:
            all_features[name] = satellite_data[band_idx]
        else:
            # Pad with zeros if we have fewer than 59 bands
            all_features[name] = np.zeros((height, width), dtype=np.float32)

    all_features.update(indices)
    all_features.update(texture_features)
    all_features.update(spatial_features)
    all_features.update(pca_features)

    # Extract feature values for valid pixels (columns start out as zeros).
    feature_matrix = np.zeros((n_valid, len(feature_names)), dtype=np.float32)

    for col, name in enumerate(feature_names):
        if name not in all_features:
            logger.warning(f"Feature '{name}' not found, using zeros")
            continue

        feature_data = np.asarray(all_features[name])
        if feature_data.ndim == 2:
            feature_values = feature_data[valid_y, valid_x]
        else:
            # Scalar feature: broadcast the single value to every pixel.
            feature_values = np.full(n_valid, feature_data)

        # Zero out NaN *and* +/-Inf: the default Inf -> float-max mapping
        # would overflow back to Inf when stored in the float32 matrix.
        feature_matrix[:, col] = np.nan_to_num(
            feature_values, nan=0.0, posinf=0.0, neginf=0.0
        )

    end_time = datetime.now()
    processing_time = (end_time - start_time).total_seconds()
    logger.info(f"Successfully extracted {len(feature_names)} features for {n_valid} pixels in {processing_time:.2f} seconds")

    return feature_matrix, valid_mask, feature_names
390
+
391
+ # Simple test function
392
def test_feature_extraction():
    """Smoke-test the feature extraction pipeline on random sample data.

    Builds a random 5-band, 100x100 float32 cube, runs
    ``extract_all_features`` on it and prints a short summary.

    Returns:
        bool: True when extraction completed, False when it raised (the
        error and its traceback are printed).
    """
    try:
        # Random sample cube: 5 bands, 100x100 pixels
        cube = np.random.random((5, 100, 100)).astype(np.float32)

        # Run the full pipeline
        matrix, mask, names = extract_all_features(cube)

        # Report what came back
        print(f"Sample data shape: {cube.shape}")
        print(f"Feature matrix shape: {matrix.shape}")
        print(f"Number of feature names: {len(names)}")
        print(f"Valid pixels: {np.sum(mask)}")
    except Exception as e:
        print(f"Feature extraction test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        return True
413
 
414
if __name__ == "__main__":
    # Script entry point: executing this file directly runs the smoke test.
    test_feature_extraction()