Spaces:
Running
Running
Update app.py
Browse filesfixed app module
app.py
CHANGED
@@ -5,7 +5,16 @@ Provides a web interface for making predictions with StableResNet
|
|
5 |
Author: najahpokkiri
|
6 |
Date: 2025-05-17
|
7 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import os
|
|
|
9 |
import torch
|
10 |
import numpy as np
|
11 |
import gradio as gr
|
@@ -15,33 +24,48 @@ import matplotlib.pyplot as plt
|
|
15 |
import matplotlib.colors as colors
|
16 |
from PIL import Image
|
17 |
import io
|
|
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
|
|
|
|
|
|
|
|
20 |
# Import model architecture
|
21 |
from model import StableResNet
|
22 |
|
23 |
class BiomassPredictorApp:
|
24 |
-
"""Gradio app for biomass prediction"""
|
25 |
|
26 |
def __init__(self, model_repo="pokkiri/biomass-model"):
|
|
|
27 |
self.model = None
|
28 |
self.package = None
|
|
|
29 |
self.model_repo = model_repo
|
30 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
31 |
|
|
|
|
|
|
|
32 |
# Load the model
|
33 |
self.load_model()
|
34 |
|
35 |
def load_model(self):
|
36 |
"""Load the model and preprocessing pipeline from HuggingFace Hub"""
|
37 |
try:
|
38 |
-
|
|
|
|
|
39 |
model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
|
40 |
package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
|
41 |
|
42 |
# Load package with metadata
|
43 |
self.package = joblib.load(package_path)
|
|
|
|
|
44 |
n_features = self.package['n_features']
|
|
|
45 |
|
46 |
# Initialize model
|
47 |
self.model = StableResNet(n_features=n_features)
|
@@ -49,17 +73,33 @@ class BiomassPredictorApp:
|
|
49 |
self.model.to(self.device)
|
50 |
self.model.eval()
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
return True
|
57 |
except Exception as e:
|
58 |
-
|
|
|
|
|
59 |
return False
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def predict_biomass(self, image_file, display_type="heatmap"):
|
62 |
"""Predict biomass from a satellite image"""
|
|
|
|
|
|
|
63 |
try:
|
64 |
# Create a temporary file to save the uploaded file
|
65 |
with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
|
@@ -67,10 +107,14 @@ class BiomassPredictorApp:
|
|
67 |
with open(image_file.name, 'rb') as f:
|
68 |
tmp_file.write(f.read())
|
69 |
|
|
|
|
|
|
|
|
|
70 |
try:
|
71 |
import rasterio
|
72 |
except ImportError:
|
73 |
-
return None, "Error: rasterio is required but not installed."
|
74 |
|
75 |
# Open the image file
|
76 |
with rasterio.open(tmp_path) as src:
|
@@ -79,29 +123,39 @@ class BiomassPredictorApp:
|
|
79 |
transform = src.transform
|
80 |
crs = src.crs
|
81 |
|
82 |
-
#
|
83 |
if image.shape[0] < self.package['n_features']:
|
84 |
-
return None, f"Error: Image has {image.shape[0]} bands, but model expects at least
|
|
|
85 |
|
86 |
-
|
87 |
|
88 |
# Process in chunks to avoid memory issues
|
89 |
-
chunk_size = 1000
|
90 |
predictions = np.zeros((height, width), dtype=np.float32)
|
91 |
|
92 |
# Create mask for valid pixels (not NaN or Inf)
|
93 |
valid_mask = np.all(np.isfinite(image), axis=0)
|
94 |
|
|
|
|
|
|
|
|
|
95 |
# Process image in chunks
|
|
|
|
|
|
|
96 |
for y_start in range(0, height, chunk_size):
|
97 |
y_end = min(y_start + chunk_size, height)
|
98 |
|
99 |
for x_start in range(0, width, chunk_size):
|
100 |
x_end = min(x_start + chunk_size, width)
|
|
|
101 |
|
102 |
# Get chunk mask
|
103 |
chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
|
104 |
if not np.any(chunk_mask):
|
|
|
105 |
continue
|
106 |
|
107 |
# Extract valid pixels
|
@@ -124,42 +178,58 @@ class BiomassPredictorApp:
|
|
124 |
batch_predictions = self.model(batch_tensor).cpu().numpy()
|
125 |
|
126 |
# Convert from log scale if needed
|
127 |
-
if self.package
|
128 |
-
|
|
|
129 |
batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
|
130 |
|
131 |
# Insert predictions back into the image
|
132 |
for idx, (i, j) in enumerate(zip(valid_y, valid_x)):
|
133 |
predictions[y_start+i, x_start+j] = batch_predictions[idx]
|
134 |
-
|
135 |
-
|
136 |
-
os.unlink(tmp_path)
|
137 |
|
138 |
# Create visualization
|
|
|
139 |
plt.figure(figsize=(12, 8))
|
140 |
|
141 |
if display_type == "heatmap":
|
142 |
# Create heatmap
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
plt.colorbar(label='Biomass (Mg/ha)')
|
145 |
plt.title('Predicted Above-Ground Biomass')
|
|
|
146 |
|
147 |
elif display_type == "rgb_overlay":
|
148 |
# Create RGB + overlay
|
149 |
if image.shape[0] >= 3:
|
150 |
# Use first 3 bands as RGB
|
151 |
rgb = image[[0, 1, 2]].transpose(1, 2, 0)
|
152 |
-
rgb = np.clip((rgb - np.percentile(rgb, 2)) / (np.percentile(rgb, 98) - np.percentile(rgb, 2)), 0, 1)
|
153 |
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
# Create mask for overlay (where we have predictions)
|
157 |
-
mask = ~np.isclose(predictions, 0)
|
158 |
overlay = np.zeros((height, width, 4))
|
159 |
|
160 |
# Create colormap for biomass
|
161 |
-
norm = colors.Normalize(
|
162 |
-
|
|
|
|
|
163 |
cmap = plt.cm.viridis
|
164 |
|
165 |
# Apply colormap
|
@@ -170,42 +240,63 @@ class BiomassPredictorApp:
|
|
170 |
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
|
171 |
label='Biomass (Mg/ha)')
|
172 |
plt.title('Biomass Prediction Overlay')
|
|
|
173 |
else:
|
174 |
-
|
|
|
|
|
175 |
plt.colorbar(label='Biomass (Mg/ha)')
|
176 |
plt.title('Predicted Above-Ground Biomass')
|
|
|
177 |
|
178 |
# Save figure to bytes buffer
|
179 |
buf = io.BytesIO()
|
180 |
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
|
181 |
buf.seek(0)
|
|
|
182 |
|
183 |
-
#
|
184 |
valid_predictions = predictions[valid_mask]
|
185 |
stats = {
|
186 |
'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
|
187 |
'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
|
188 |
'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
|
189 |
-
'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
|
190 |
-
'Total Biomass': f"{np.sum(valid_predictions) * (transform[0] * transform[0]) / 10000:.2f} Mg",
|
191 |
-
'Area': f"{np.sum(valid_mask) * (transform[0] * transform[0]) / 10000:.2f} hectares"
|
192 |
}
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
# Format statistics as markdown
|
195 |
stats_md = "### Biomass Statistics\n\n"
|
196 |
stats_md += "| Metric | Value |\n|--------|-------|\n"
|
197 |
for k, v in stats.items():
|
198 |
stats_md += f"| {k} | {v} |\n"
|
199 |
|
200 |
-
#
|
201 |
-
|
|
|
|
|
|
|
202 |
|
203 |
# Return visualization and statistics
|
204 |
return Image.open(buf), stats_md
|
205 |
|
206 |
except Exception as e:
|
|
|
|
|
|
|
207 |
import traceback
|
208 |
-
|
|
|
|
|
|
|
209 |
|
210 |
def create_interface(self):
|
211 |
"""Create Gradio interface"""
|
@@ -220,7 +311,7 @@ class BiomassPredictorApp:
|
|
220 |
""")
|
221 |
|
222 |
with gr.Row():
|
223 |
-
with gr.Column():
|
224 |
input_image = gr.File(
|
225 |
label="Upload Satellite Image (GeoTIFF)",
|
226 |
file_types=[".tif", ".tiff"]
|
@@ -232,9 +323,9 @@ class BiomassPredictorApp:
|
|
232 |
label="Display Type"
|
233 |
)
|
234 |
|
235 |
-
submit_btn = gr.Button("Generate Biomass Prediction")
|
236 |
|
237 |
-
with gr.Column():
|
238 |
output_image = gr.Image(
|
239 |
label="Biomass Prediction Map",
|
240 |
type="pil"
|
@@ -245,7 +336,7 @@ class BiomassPredictorApp:
|
|
245 |
)
|
246 |
|
247 |
with gr.Accordion("About", open=False):
|
248 |
-
gr.Markdown(
|
249 |
## About This Model
|
250 |
|
251 |
This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
|
@@ -255,9 +346,9 @@ class BiomassPredictorApp:
|
|
255 |
- Architecture: StableResNet
|
256 |
- Input: Multi-spectral satellite imagery
|
257 |
- Output: Above-ground biomass (Mg/ha)
|
258 |
-
- Creator:
|
259 |
-
- Date:
|
260 |
-
- Model Repository: [
|
261 |
|
262 |
### How It Works
|
263 |
|
@@ -267,6 +358,27 @@ class BiomassPredictorApp:
|
|
267 |
4. Results are visualized as a heatmap or RGB overlay
|
268 |
""")
|
269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
submit_btn.click(
|
271 |
fn=self.predict_biomass,
|
272 |
inputs=[input_image, display_type],
|
@@ -277,9 +389,19 @@ class BiomassPredictorApp:
|
|
277 |
|
278 |
def launch_app():
|
279 |
"""Launch the Gradio app"""
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
if __name__ == "__main__":
|
285 |
launch_app()
|
|
|
5 |
Author: najahpokkiri
|
6 |
Date: 2025-05-17
|
7 |
"""
|
8 |
+
"""
|
9 |
+
Biomass Prediction Gradio App
|
10 |
+
Author: najahpokkiri
|
11 |
+
Date: 2025-05-17
|
12 |
+
|
13 |
+
This app allows users to predict above-ground biomass from satellite imagery
|
14 |
+
using a trained StableResNet model.
|
15 |
+
"""
|
16 |
import os
|
17 |
+
import sys
|
18 |
import torch
|
19 |
import numpy as np
|
20 |
import gradio as gr
|
|
|
24 |
import matplotlib.colors as colors
|
25 |
from PIL import Image
|
26 |
import io
|
27 |
+
import logging
|
28 |
from huggingface_hub import hf_hub_download
|
29 |
|
30 |
+
# Configure logger
|
31 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
# Import model architecture
|
35 |
from model import StableResNet
|
36 |
|
37 |
class BiomassPredictorApp:
|
38 |
+
"""Gradio app for biomass prediction from satellite imagery"""
|
39 |
|
40 |
def __init__(self, model_repo="pokkiri/biomass-model"):
|
41 |
+
"""Initialize the app with model repository information"""
|
42 |
self.model = None
|
43 |
self.package = None
|
44 |
+
self.feature_names = []
|
45 |
self.model_repo = model_repo
|
46 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
47 |
|
48 |
+
# Cache for storing temporary files
|
49 |
+
self.temp_files = []
|
50 |
+
|
51 |
# Load the model
|
52 |
self.load_model()
|
53 |
|
54 |
def load_model(self):
|
55 |
"""Load the model and preprocessing pipeline from HuggingFace Hub"""
|
56 |
try:
|
57 |
+
logger.info(f"Loading model from {self.model_repo}")
|
58 |
+
|
59 |
+
# Download model files from HuggingFace
|
60 |
model_path = hf_hub_download(repo_id=self.model_repo, filename="model.pt")
|
61 |
package_path = hf_hub_download(repo_id=self.model_repo, filename="model_package.pkl")
|
62 |
|
63 |
# Load package with metadata
|
64 |
self.package = joblib.load(package_path)
|
65 |
+
|
66 |
+
# Extract information from package
|
67 |
n_features = self.package['n_features']
|
68 |
+
self.feature_names = self.package.get('feature_names', [f"feature_{i}" for i in range(n_features)])
|
69 |
|
70 |
# Initialize model
|
71 |
self.model = StableResNet(n_features=n_features)
|
|
|
73 |
self.model.to(self.device)
|
74 |
self.model.eval()
|
75 |
|
76 |
+
logger.info(f"Model loaded successfully from {self.model_repo}")
|
77 |
+
logger.info(f"Number of features: {n_features}")
|
78 |
+
logger.info(f"Using device: {self.device}")
|
79 |
|
80 |
return True
|
81 |
except Exception as e:
|
82 |
+
logger.error(f"Error loading model: {e}")
|
83 |
+
import traceback
|
84 |
+
logger.error(traceback.format_exc())
|
85 |
return False
|
86 |
|
87 |
+
def cleanup(self):
|
88 |
+
"""Clean up temporary files"""
|
89 |
+
for tmp_path in self.temp_files:
|
90 |
+
try:
|
91 |
+
if os.path.exists(tmp_path):
|
92 |
+
os.unlink(tmp_path)
|
93 |
+
except Exception as e:
|
94 |
+
logger.warning(f"Failed to remove temporary file {tmp_path}: {e}")
|
95 |
+
|
96 |
+
self.temp_files = []
|
97 |
+
|
98 |
def predict_biomass(self, image_file, display_type="heatmap"):
|
99 |
"""Predict biomass from a satellite image"""
|
100 |
+
if self.model is None:
|
101 |
+
return None, "Error: Model not loaded. Please check logs for details."
|
102 |
+
|
103 |
try:
|
104 |
# Create a temporary file to save the uploaded file
|
105 |
with tempfile.NamedTemporaryFile(suffix='.tif', delete=False) as tmp_file:
|
|
|
107 |
with open(image_file.name, 'rb') as f:
|
108 |
tmp_file.write(f.read())
|
109 |
|
110 |
+
# Add to list for cleanup later
|
111 |
+
self.temp_files.append(tmp_path)
|
112 |
+
|
113 |
+
# Ensure rasterio is available
|
114 |
try:
|
115 |
import rasterio
|
116 |
except ImportError:
|
117 |
+
return None, "Error: rasterio is required but not installed. Please install with: pip install rasterio"
|
118 |
|
119 |
# Open the image file
|
120 |
with rasterio.open(tmp_path) as src:
|
|
|
123 |
transform = src.transform
|
124 |
crs = src.crs
|
125 |
|
126 |
+
# Validate image dimensions
|
127 |
if image.shape[0] < self.package['n_features']:
|
128 |
+
return None, (f"Error: Image has {image.shape[0]} bands, but model expects at least "
|
129 |
+
f"{self.package['n_features']} features.")
|
130 |
|
131 |
+
logger.info(f"Processing image: {height}x{width} pixels, {image.shape[0]} bands")
|
132 |
|
133 |
# Process in chunks to avoid memory issues
|
134 |
+
chunk_size = min(1000, height, width) # Adjust chunk size for smaller images
|
135 |
predictions = np.zeros((height, width), dtype=np.float32)
|
136 |
|
137 |
# Create mask for valid pixels (not NaN or Inf)
|
138 |
valid_mask = np.all(np.isfinite(image), axis=0)
|
139 |
|
140 |
+
# Show progress indicator
|
141 |
+
progress_text = f"Processing {height}x{width} image..."
|
142 |
+
logger.info(progress_text)
|
143 |
+
|
144 |
# Process image in chunks
|
145 |
+
total_chunks = ((height + chunk_size - 1) // chunk_size) * ((width + chunk_size - 1) // chunk_size)
|
146 |
+
chunk_count = 0
|
147 |
+
|
148 |
for y_start in range(0, height, chunk_size):
|
149 |
y_end = min(y_start + chunk_size, height)
|
150 |
|
151 |
for x_start in range(0, width, chunk_size):
|
152 |
x_end = min(x_start + chunk_size, width)
|
153 |
+
chunk_count += 1
|
154 |
|
155 |
# Get chunk mask
|
156 |
chunk_mask = valid_mask[y_start:y_end, x_start:x_end]
|
157 |
if not np.any(chunk_mask):
|
158 |
+
logger.info(f"Skipping chunk {chunk_count}/{total_chunks} (no valid pixels)")
|
159 |
continue
|
160 |
|
161 |
# Extract valid pixels
|
|
|
178 |
batch_predictions = self.model(batch_tensor).cpu().numpy()
|
179 |
|
180 |
# Convert from log scale if needed
|
181 |
+
if self.package.get('use_log_transform', False):
|
182 |
+
epsilon = self.package.get('epsilon', 1.0)
|
183 |
+
batch_predictions = np.exp(batch_predictions) - epsilon
|
184 |
batch_predictions = np.maximum(batch_predictions, 0) # Ensure non-negative
|
185 |
|
186 |
# Insert predictions back into the image
|
187 |
for idx, (i, j) in enumerate(zip(valid_y, valid_x)):
|
188 |
predictions[y_start+i, x_start+j] = batch_predictions[idx]
|
189 |
+
|
190 |
+
logger.info(f"Processed chunk {chunk_count}/{total_chunks}")
|
|
|
191 |
|
192 |
# Create visualization
|
193 |
+
logger.info("Creating visualization...")
|
194 |
plt.figure(figsize=(12, 8))
|
195 |
|
196 |
if display_type == "heatmap":
|
197 |
# Create heatmap
|
198 |
+
# Use masked array for better visualization
|
199 |
+
masked_predictions = np.ma.masked_where(~valid_mask, predictions)
|
200 |
+
|
201 |
+
# Set min/max values based on percentiles for better contrast
|
202 |
+
vmin = np.percentile(predictions[valid_mask], 1)
|
203 |
+
vmax = np.percentile(predictions[valid_mask], 99)
|
204 |
+
|
205 |
+
plt.imshow(masked_predictions, cmap='viridis', vmin=vmin, vmax=vmax)
|
206 |
plt.colorbar(label='Biomass (Mg/ha)')
|
207 |
plt.title('Predicted Above-Ground Biomass')
|
208 |
+
plt.axis('off') # Hide axes for cleaner visualization
|
209 |
|
210 |
elif display_type == "rgb_overlay":
|
211 |
# Create RGB + overlay
|
212 |
if image.shape[0] >= 3:
|
213 |
# Use first 3 bands as RGB
|
214 |
rgb = image[[0, 1, 2]].transpose(1, 2, 0)
|
|
|
215 |
|
216 |
+
# Enhance contrast with percentile-based normalization
|
217 |
+
p2 = np.percentile(rgb[np.isfinite(rgb)], 2)
|
218 |
+
p98 = np.percentile(rgb[np.isfinite(rgb)], 98)
|
219 |
+
rgb_norm = np.clip((rgb - p2) / (p98 - p2), 0, 1)
|
220 |
+
|
221 |
+
# Display RGB image
|
222 |
+
plt.imshow(rgb_norm)
|
223 |
|
224 |
# Create mask for overlay (where we have predictions)
|
225 |
+
mask = valid_mask & (~np.isclose(predictions, 0))
|
226 |
overlay = np.zeros((height, width, 4))
|
227 |
|
228 |
# Create colormap for biomass
|
229 |
+
norm = colors.Normalize(
|
230 |
+
vmin=np.percentile(predictions[mask], 5),
|
231 |
+
vmax=np.percentile(predictions[mask], 95)
|
232 |
+
)
|
233 |
cmap = plt.cm.viridis
|
234 |
|
235 |
# Apply colormap
|
|
|
240 |
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
|
241 |
label='Biomass (Mg/ha)')
|
242 |
plt.title('Biomass Prediction Overlay')
|
243 |
+
plt.axis('off')
|
244 |
else:
|
245 |
+
# Fallback to regular heatmap if not enough bands for RGB
|
246 |
+
masked_predictions = np.ma.masked_where(~valid_mask, predictions)
|
247 |
+
plt.imshow(masked_predictions, cmap='viridis')
|
248 |
plt.colorbar(label='Biomass (Mg/ha)')
|
249 |
plt.title('Predicted Above-Ground Biomass')
|
250 |
+
plt.axis('off')
|
251 |
|
252 |
# Save figure to bytes buffer
|
253 |
buf = io.BytesIO()
|
254 |
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
|
255 |
buf.seek(0)
|
256 |
+
plt.close()
|
257 |
|
258 |
+
# Calculate summary statistics
|
259 |
valid_predictions = predictions[valid_mask]
|
260 |
stats = {
|
261 |
'Mean Biomass': f"{np.mean(valid_predictions):.2f} Mg/ha",
|
262 |
'Median Biomass': f"{np.median(valid_predictions):.2f} Mg/ha",
|
263 |
'Min Biomass': f"{np.min(valid_predictions):.2f} Mg/ha",
|
264 |
+
'Max Biomass': f"{np.max(valid_predictions):.2f} Mg/ha"
|
|
|
|
|
265 |
}
|
266 |
|
267 |
+
# Add area and total biomass if transform is available
|
268 |
+
if transform is not None:
|
269 |
+
pixel_area_m2 = abs(transform[0] * transform[4]) # Assuming square pixels
|
270 |
+
total_biomass = np.sum(valid_predictions) * (pixel_area_m2 / 10000) # Convert to hectares
|
271 |
+
area_hectares = np.sum(valid_mask) * (pixel_area_m2 / 10000)
|
272 |
+
|
273 |
+
stats['Total Biomass'] = f"{total_biomass:.2f} Mg"
|
274 |
+
stats['Area'] = f"{area_hectares:.2f} hectares"
|
275 |
+
|
276 |
# Format statistics as markdown
|
277 |
stats_md = "### Biomass Statistics\n\n"
|
278 |
stats_md += "| Metric | Value |\n|--------|-------|\n"
|
279 |
for k, v in stats.items():
|
280 |
stats_md += f"| {k} | {v} |\n"
|
281 |
|
282 |
+
# Add processing info
|
283 |
+
stats_md += f"\n\n*Processed {np.sum(valid_mask):,} valid pixels*"
|
284 |
+
|
285 |
+
# Cleanup temporary files
|
286 |
+
self.cleanup()
|
287 |
|
288 |
# Return visualization and statistics
|
289 |
return Image.open(buf), stats_md
|
290 |
|
291 |
except Exception as e:
|
292 |
+
# Ensure cleanup even on error
|
293 |
+
self.cleanup()
|
294 |
+
|
295 |
import traceback
|
296 |
+
logger.error(f"Error predicting biomass: {e}")
|
297 |
+
logger.error(traceback.format_exc())
|
298 |
+
|
299 |
+
return None, f"Error predicting biomass: {str(e)}\n\nPlease check logs for details."
|
300 |
|
301 |
def create_interface(self):
|
302 |
"""Create Gradio interface"""
|
|
|
311 |
""")
|
312 |
|
313 |
with gr.Row():
|
314 |
+
with gr.Column(scale=1):
|
315 |
input_image = gr.File(
|
316 |
label="Upload Satellite Image (GeoTIFF)",
|
317 |
file_types=[".tif", ".tiff"]
|
|
|
323 |
label="Display Type"
|
324 |
)
|
325 |
|
326 |
+
submit_btn = gr.Button("Generate Biomass Prediction", variant="primary")
|
327 |
|
328 |
+
with gr.Column(scale=2):
|
329 |
output_image = gr.Image(
|
330 |
label="Biomass Prediction Map",
|
331 |
type="pil"
|
|
|
336 |
)
|
337 |
|
338 |
with gr.Accordion("About", open=False):
|
339 |
+
gr.Markdown("""
|
340 |
## About This Model
|
341 |
|
342 |
This biomass prediction model uses the StableResNet architecture to predict above-ground biomass from satellite imagery.
|
|
|
346 |
- Architecture: StableResNet
|
347 |
- Input: Multi-spectral satellite imagery
|
348 |
- Output: Above-ground biomass (Mg/ha)
|
349 |
+
- Creator: najahpokkiri
|
350 |
+
- Date: 2025-05-17
|
351 |
+
- Model Repository: [pokkiri/biomass-model](https://huggingface.co/pokkiri/biomass-model)
|
352 |
|
353 |
### How It Works
|
354 |
|
|
|
358 |
4. Results are visualized as a heatmap or RGB overlay
|
359 |
""")
|
360 |
|
361 |
+
with gr.Accordion("Examples", open=False):
|
362 |
+
gr.Markdown("""
|
363 |
+
### Example Data
|
364 |
+
|
365 |
+
To try the model, you can use sample GeoTIFF files with the following characteristics:
|
366 |
+
|
367 |
+
- Multi-band satellite imagery (Sentinel-2, Landsat, etc.)
|
368 |
+
- Contains bands in the proper order (see documentation)
|
369 |
+
- Images should be relatively small (< 2000x2000 pixels) for faster processing
|
370 |
+
|
371 |
+
You can find sample data at:
|
372 |
+
- [Earth Explorer](https://earthexplorer.usgs.gov/)
|
373 |
+
- [Copernicus Open Access Hub](https://scihub.copernicus.eu/)
|
374 |
+
- [Planetary Computer](https://planetarycomputer.microsoft.com/)
|
375 |
+
""")
|
376 |
+
|
377 |
+
# Add a warning if model failed to load
|
378 |
+
if self.model is None:
|
379 |
+
gr.Warning("⚠️ Model failed to load. The app may not work correctly. Check logs for details.")
|
380 |
+
|
381 |
+
# Connect the submit button
|
382 |
submit_btn.click(
|
383 |
fn=self.predict_biomass,
|
384 |
inputs=[input_image, display_type],
|
|
|
389 |
|
390 |
def launch_app():
|
391 |
"""Launch the Gradio app"""
|
392 |
+
try:
|
393 |
+
# Create app instance
|
394 |
+
app = BiomassPredictorApp()
|
395 |
+
|
396 |
+
# Create interface
|
397 |
+
interface = app.create_interface()
|
398 |
+
|
399 |
+
# Launch interface
|
400 |
+
interface.launch(share=True)
|
401 |
+
except Exception as e:
|
402 |
+
logger.error(f"Error launching app: {e}")
|
403 |
+
import traceback
|
404 |
+
logger.error(traceback.format_exc())
|
405 |
|
406 |
if __name__ == "__main__":
|
407 |
launch_app()
|