pokkiri commited on
Commit
3a50655
·
verified ·
1 Parent(s): 2448252

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +51 -5
  2. app.py +285 -0
  3. model.py +67 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,58 @@
1
  ---
2
- title: Biomass App
3
- emoji: 📉
4
- colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.29.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Biomass Prediction App
3
+ emoji: 🌳
4
+ colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.40.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # Biomass Prediction App
13
+
14
+ This Gradio app demonstrates biomass prediction from satellite imagery using a StableResNet model.
15
+
16
+ ## Features
17
+
18
+ - Upload a multi-band satellite image (GeoTIFF format)
19
+ - View predicted biomass as a heatmap or RGB overlay
20
+ - Get statistics on predicted biomass values
21
+
22
+ ## Usage
23
+
24
+ 1. Upload a multi-band satellite image (GeoTIFF)
25
+ 2. Choose display type (heatmap or RGB overlay)
26
+ 3. Click "Generate Biomass Prediction"
27
+ 4. View the prediction map and statistics
28
+
29
+ ## Model Information
30
+
31
+ - **Created by:** pokkiri
32
+ - **Date:** 2025-05-17
33
+ - **Architecture:** StableResNet
34
+ - **Model Repository:** [pokkiri/biomass-model](https://huggingface.co/pokkiri/biomass-model)
35
+ - **Input:** Multi-spectral satellite imagery
36
+ - **Output:** Above-ground biomass (Mg/ha)
37
+
38
+ ## Requirements
39
+
40
+ - GeoTIFF image file with multiple spectral bands
41
+ - Image bands should match those used during model training
42
+
43
+ ## How It Works
44
+
45
+ The app connects to the biomass prediction model hosted on HuggingFace Hub. When you upload a satellite image:
46
+
47
+ 1. The app loads your GeoTIFF file
48
+ 2. For each pixel, it extracts the spectral values
49
+ 3. These values are processed through the StableResNet model
50
+ 4. The model predicts biomass values for each pixel
51
+ 5. Results are visualized on a map with summary statistics
52
+
53
+ ## Citation
54
+
55
+ ```
56
+ @misc{biomass_app, author = {pokkiri}, title = {Biomass Prediction App}, year = {2025}, publisher = {HuggingFace Spaces}, howpublished = {https://huggingface.co/spaces/pokkiri/biomass-app} }
57
+
58
+ ```
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Biomass Prediction
3
+ Provides a web interface for making predictions with StableResNet
4
+
5
+ Author: najahpokkiri
6
+ Date: 2025-05-17
7
+ """
8
+ import os
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ import joblib
13
+ import tempfile
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.colors as colors
16
+ from PIL import Image
17
+ import io
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ # Import model architecture
21
+ from model import StableResNet
22
+
23
+ class BiomassPredictorApp:
24
+ """Gradio app for biomass prediction"""
25
+
26
+ def __init__(self, model_repo="pokkiri/biomass-model"):
27
+ self.model = None
28
+ self.package = None
29
+ self.model_repo = model_repo
30
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
+
32
+ # Load the model
33
+ self.load_model()
34
+
35
+ def load_model(self):
36
+ """Load the model and preprocessing pipeline from HuggingFace Hub"""
37
+ try:
38
+ # Download files from HuggingFace
39
+ model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
40
+ package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
41
+
42
+ # Load package with metadata
43
+ self.package = joblib.load(package_path)
44
+ n_features = self.package['n_features']
45
+
46
+ # Initialize model
47
+ self.model = StableResNet(n_features=n_features)
48
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
49
+ self.model.to(self.device)
50
+ self.model.eval()
51
+
52
+ print(f"Model loaded successfully from {self.model_repo}")
53
+ print(f"Number of features: {n_features}")
54
+ print(f"Using device: {self.device}")
55
+
56
+ return True
57
+ except Exception as e:
58
+ print(f"Error loading model: {e}")
59
+ return False
60
+
61
+ def predict_biomass(self, image_file, display_type="heatmap"):
62
+ """Predict biomass from a satellite image"""
63
+ try:
64
+ # Create a temporary file to save the uploaded file
65
+ with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
66
+ tmp_path = tmp_file.name
67
+ with open(image_file.name, 'rb') as f:
68
+ tmp_file.write(f.read())
69
+
70
+ try:
71
+ import rasterio
72
+ except ImportError:
73
+ return None, "Error: rasterio is required but not installed."
74
+
75
+ # Open the image file
76
+ with rasterio.open(tmp_path) as src:
77
+ image = src.read()
78
+ height, width = image.shape[1], image.shape[2]
79
+ transform = src.transform
80
+ crs = src.crs
81
+
82
+ # Check if number of bands matches expected features
83
+ if image.shape[0] < self.package['n_features']:
84
+ return None, f"Error: Image has {image.shape[0]} bands, but model expects at least {self.package['n_features']} features."
85
+
86
+ print(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
87
+
88
+ # Process in chunks to avoid memory issues
89
+ chunk_size = 1000
90
+ predictions = np.zeros((height, width), dtype=np.float32)
91
+
92
+ # Create mask for valid pixels (not NaN or Inf)
93
+ valid_mask = np.all(np.isfinite(image), axis=0)
94
+
95
+ # Process image in chunks
96
+ for y_start in range(0, height, chunk_size):
97
+ y_end = min(y_start + chunk_size, height)
98
+
99
+ for x_start in range(0, width, chunk_size):
100
+ x_end = min(x_start + chunk_size, width)
101
+
102
+ # Get chunk mask
103
+ chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
104
+ if not np.any(chunk_mask):
105
+ continue
106
+
107
+ # Extract valid pixels
108
+ valid_y, valid_x = np.where(chunk_mask)
109
+
110
+ # Extract features for valid pixels
111
+ pixel_features = []
112
+ for i, j in zip(valid_y, valid_x):
113
+ # Extract bands
114
+ pixel_values = image[:, y_start+i, x_start+j]
115
+ pixel_features.append(pixel_values)
116
+
117
+ # Convert to array and scale features
118
+ pixel_features = np.array(pixel_features)
119
+ pixel_features_scaled = self.package['scaler'].transform(pixel_features)
120
+
121
+ # Make predictions
122
+ with torch.no_grad():
123
+ batch_tensor = torch.tensor(pixel_features_scaled, dtype=torch.float32).to(self.device)
124
+ batch_predictions = self.model(batch_tensor).cpu().numpy()
125
+
126
+ # Convert from log scale if needed
127
+ if self.package['use_log_transform']:
128
+ batch_predictions = np.exp(batch_predictions) - self.package.get('epsilon', 1.0)
129
+ batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
130
+
131
+ # Insert predictions back into the image
132
+ for idx, (i, j) in enumerate(zip(valid_y, valid_x)):
133
+ predictions[y_start+i, x_start+j] = batch_predictions[idx]
134
+
135
+ # Delete temporary file
136
+ os.unlink(tmp_path)
137
+
138
+ # Create visualization
139
+ plt.figure(figsize=(12, 8))
140
+
141
+ if display_type == "heatmap":
142
+ # Create heatmap
143
+ plt.imshow(predictions, cmap='viridis')
144
+ plt.colorbar(label='Biomass (Mg/ha)')
145
+ plt.title('Predicted Above-Ground Biomass')
146
+
147
+ elif display_type == "rgb_overlay":
148
+ # Create RGB + overlay
149
+ if image.shape[0] >= 3:
150
+ # Use first 3 bands as RGB
151
+ rgb = image[[0, 1, 2]].transpose(1, 2, 0)
152
+ rgb = np.clip((rgb - np.percentile(rgb, 2)) / (np.percentile(rgb, 98) - np.percentile(rgb, 2)), 0, 1)
153
+
154
+ plt.imshow(rgb)
155
+
156
+ # Create mask for overlay (where we have predictions)
157
+ mask = ~np.isclose(predictions, 0)
158
+ overlay = np.zeros((height, width, 4))
159
+
160
+ # Create colormap for biomass
161
+ norm = colors.Normalize(vmin=np.percentile(predictions[mask], 5),
162
+ vmax=np.percentile(predictions[mask], 95))
163
+ cmap = plt.cm.viridis
164
+
165
+ # Apply colormap
166
+ overlay[..., :3] = cmap(norm(predictions))[..., :3]
167
+ overlay[..., 3] = np.where(mask, 0.7, 0) # Set alpha channel
168
+
169
+ plt.imshow(overlay)
170
+ plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
171
+ label='Biomass (Mg/ha)')
172
+ plt.title('Biomass Prediction Overlay')
173
+ else:
174
+ plt.imshow(predictions, cmap='viridis')
175
+ plt.colorbar(label='Biomass (Mg/ha)')
176
+ plt.title('Predicted Above-Ground Biomass')
177
+
178
+ # Save figure to bytes buffer
179
+ buf = io.BytesIO()
180
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
181
+ buf.seek(0)
182
+
183
+ # Create summary statistics
184
+ valid_predictions = predictions[valid_mask]
185
+ stats = {
186
+ 'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
187
+ 'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
188
+ 'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
189
+ 'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha",
190
+ 'Total Biomass': f"{np.sum(valid_predictions) * (transform[0] * transform[0]) / 10000:.2f} Mg",
191
+ 'Area': f"{np.sum(valid_mask) * (transform[0] * transform[0]) / 10000:.2f} hectares"
192
+ }
193
+
194
+ # Format statistics as markdown
195
+ stats_md = "### Biomass Statistics\n\n"
196
+ stats_md += "| Metric | Value |\n|--------|-------|\n"
197
+ for k, v in stats.items():
198
+ stats_md += f"| {k} | {v} |\n"
199
+
200
+ # Close the plot
201
+ plt.close()
202
+
203
+ # Return visualization and statistics
204
+ return Image.open(buf), stats_md
205
+
206
+ except Exception as e:
207
+ import traceback
208
+ return None, f"Error predicting biomass: {str(e)}\n\n{traceback.format_exc()}"
209
+
210
+ def create_interface(self):
211
+ """Create Gradio interface"""
212
+ with gr.Blocks(title="Biomass Prediction Model") as interface:
213
+ gr.Markdown("# Above-Ground Biomass Prediction")
214
+ gr.Markdown("""
215
+ Upload a multi-band satellite image to predict above-ground biomass (AGB) across the landscape.
216
+
217
+ **Requirements:**
218
+ - Image must be a GeoTIFF with spectral bands
219
+ - For best results, image should contain similar bands to those used in training
220
+ """)
221
+
222
+ with gr.Row():
223
+ with gr.Column():
224
+ input_image = gr.File(
225
+ label="Upload Satellite Image (GeoTIFF)",
226
+ file_types=[".tif", ".tiff"]
227
+ )
228
+
229
+ display_type = gr.Radio(
230
+ choices=["heatmap", "rgb_overlay"],
231
+ value="heatmap",
232
+ label="Display Type"
233
+ )
234
+
235
+ submit_btn = gr.Button("Generate Biomass Prediction")
236
+
237
+ with gr.Column():
238
+ output_image = gr.Image(
239
+ label="Biomass Prediction Map",
240
+ type="pil"
241
+ )
242
+
243
+ output_stats = gr.Markdown(
244
+ label="Statistics"
245
+ )
246
+
247
+ with gr.Accordion("About", open=False):
248
+ gr.Markdown(f"""
249
+ ## About This Model
250
+
251
+ This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
252
+
253
+ ### Model Details
254
+
255
+ - Architecture: StableResNet
256
+ - Input: Multi-spectral satellite imagery
257
+ - Output: Above-ground biomass (Mg/ha)
258
+ - Creator: {pokkiri}
259
+ - Date: {2025-05-17}
260
+ - Model Repository: [{pokkiri/biomass-model}](https://huggingface.co/{pokkiri/biomass-model})
261
+
262
+ ### How It Works
263
+
264
+ 1. The model extracts features from each pixel in the satellite image
265
+ 2. These features are processed through the StableResNet model
266
+ 3. The model outputs a biomass prediction for each pixel
267
+ 4. Results are visualized as a heatmap or RGB overlay
268
+ """)
269
+
270
+ submit_btn.click(
271
+ fn=self.predict_biomass,
272
+ inputs=[input_image, display_type],
273
+ outputs=[output_image, output_stats]
274
+ )
275
+
276
+ return interface
277
+
278
+ def launch_app():
279
+ """Launch the Gradio app"""
280
+ app = BiomassPredictorApp()
281
+ interface = app.create_interface()
282
+ interface.launch()
283
+
284
+ if __name__ == "__main__":
285
+ launch_app()
model.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StableResNet Model for Biomass Prediction
3
+ A numerically stable ResNet architecture for regression tasks
4
+
5
+ Author: najahpokkiri
6
+ Date: 2025-05-17
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ class StableResNet(nn.Module):
12
+ """Numerically stable ResNet for biomass regression"""
13
+ def __init__(self, n_features, dropout=0.2):
14
+ super().__init__()
15
+
16
+ self.input_proj = nn.Sequential(
17
+ nn.Linear(n_features, 256),
18
+ nn.LayerNorm(256),
19
+ nn.ReLU(),
20
+ nn.Dropout(dropout)
21
+ )
22
+
23
+ self.layer1 = self._make_simple_resblock(256, 256)
24
+ self.layer2 = self._make_simple_resblock(256, 128)
25
+ self.layer3 = self._make_simple_resblock(128, 64)
26
+
27
+ self.regressor = nn.Sequential(
28
+ nn.Linear(64, 32),
29
+ nn.ReLU(),
30
+ nn.Linear(32, 1)
31
+ )
32
+
33
+ self._init_weights()
34
+
35
+ def _make_simple_resblock(self, in_dim, out_dim):
36
+ return nn.Sequential(
37
+ nn.Linear(in_dim, out_dim),
38
+ nn.BatchNorm1d(out_dim),
39
+ nn.ReLU(),
40
+ nn.Linear(out_dim, out_dim),
41
+ nn.BatchNorm1d(out_dim),
42
+ nn.ReLU()
43
+ ) if in_dim == out_dim else nn.Sequential(
44
+ nn.Linear(in_dim, out_dim),
45
+ nn.BatchNorm1d(out_dim),
46
+ nn.ReLU(),
47
+ )
48
+
49
+ def _init_weights(self):
50
+ for m in self.modules():
51
+ if isinstance(m, nn.Linear):
52
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
53
+ if m.bias is not None:
54
+ nn.init.zeros_(m.bias)
55
+
56
+ def forward(self, x):
57
+ x = self.input_proj(x)
58
+
59
+ identity = x
60
+ out = self.layer1(x)
61
+ x = out + identity
62
+
63
+ x = self.layer2(x)
64
+ x = self.layer3(x)
65
+
66
+ x = self.regressor(x)
67
+ return x.squeeze()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=1.10.0
2
+ numpy>=1.20.0
3
+ joblib>=1.1.0
4
+ rasterio>=1.2.0
5
+ huggingface_hub>=0.10.0
6
+ matplotlib>=3.5.0
7
+ gradio>=3.0.0
8
+ pillow>=8.0.0