HoeioUser commited on
Commit
6be8089
ยท
verified ยท
1 Parent(s): bdddd2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -50
app.py CHANGED
@@ -212,57 +212,78 @@ class ClimatePredictor:
212
  transforms.Normalize(mean=[0.5], std=[0.5])
213
  ])
214
 
 
 
 
 
 
 
 
215
  def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
216
  elevation_data, wind_speed, wind_direction,
217
  temperature, humidity):
218
  """Gradio ์ธํ„ฐํŽ˜์ด์Šค์šฉ ์˜ˆ์ธก ํ•จ์ˆ˜"""
219
- # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
220
- rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
221
- ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_image)).unsqueeze(0)
222
- terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_image)).unsqueeze(0)
223
-
224
- # ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
225
- elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
226
- elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
227
-
228
- # ๊ธฐ์ƒ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
229
- weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
230
- weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
231
- weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
232
-
233
- # ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
234
- sample = {
235
- 'rgb': rgb_tensor.to(self.device),
236
- 'ndvi': ndvi_tensor.to(self.device),
237
- 'terrain': terrain_tensor.to(self.device),
238
- 'elevation': elevation_tensor.to(self.device),
239
- 'weather_features': weather_features.to(self.device)
240
- }
241
-
242
- # ์˜ˆ์ธก
243
- with torch.no_grad():
244
- wind_pred, solar_pred = self.model(sample)
245
-
246
- # ๊ฒฐ๊ณผ๋ฅผ numpy ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
247
- wind_map = wind_pred.cpu().numpy()[0, 0]
248
- solar_map = solar_pred.cpu().numpy()[0, 0]
249
-
250
- # ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
251
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
252
-
253
- # ํ’๋ ฅ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
254
- sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
255
- ax1.set_title('Wind Power Potential Map')
256
-
257
- # ํƒœ์–‘๊ด‘ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
258
- sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
259
- ax2.set_title('Solar Power Potential Map')
260
-
261
- plt.tight_layout()
262
-
263
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
266
  def create_gradio_interface():
267
  predictor = ClimatePredictor('best_model.pth')
268
 
@@ -282,8 +303,8 @@ def create_gradio_interface():
282
  fn=predict_and_visualize,
283
  inputs=[
284
  gr.Image(label="RGB Satellite Image", type="numpy"),
285
- gr.Image(label="NDVI Image", type="numpy"),
286
- gr.Image(label="Terrain Map", type="numpy"),
287
  gr.File(label="Elevation Data (NPY file)"),
288
  gr.Number(label="Wind Speed (m/s)", value=5.0),
289
  gr.Number(label="Wind Direction (degrees)", value=180.0),
@@ -292,7 +313,8 @@ def create_gradio_interface():
292
  ],
293
  outputs=gr.Plot(label="Prediction Results"),
294
  title="Renewable Energy Potential Predictor",
295
- description="Upload satellite imagery and environmental data to predict wind and solar power potential.",
 
296
  examples=[
297
  [
298
  "examples/rgb_example.png",
@@ -305,6 +327,7 @@ def create_gradio_interface():
305
  )
306
  return interface
307
 
 
308
  if __name__ == "__main__":
309
  interface = create_gradio_interface()
310
- interface.launch()
 
212
  transforms.Normalize(mean=[0.5], std=[0.5])
213
  ])
214
 
215
+ def convert_to_single_channel(self, image_array):
216
+ """RGB ์ด๋ฏธ์ง€๋ฅผ ๋‹จ์ผ ์ฑ„๋„๋กœ ๋ณ€ํ™˜"""
217
+ if len(image_array.shape) == 3:
218
+ # RGB to grayscale conversion
219
+ return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
220
+ return image_array
221
+
222
  def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
223
  elevation_data, wind_speed, wind_direction,
224
  temperature, humidity):
225
  """Gradio ์ธํ„ฐํŽ˜์ด์Šค์šฉ ์˜ˆ์ธก ํ•จ์ˆ˜"""
226
+ try:
227
+ # RGB ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
228
+ rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
229
+
230
+ # NDVI ์ด๋ฏธ์ง€๋ฅผ ๋‹จ์ผ ์ฑ„๋„๋กœ ๋ณ€ํ™˜ ํ›„ ์ „์ฒ˜๋ฆฌ
231
+ ndvi_gray = self.convert_to_single_channel(ndvi_image)
232
+ ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_gray.astype(np.uint8))).unsqueeze(0)
233
+
234
+ # Terrain ์ด๋ฏธ์ง€๋ฅผ ๋‹จ์ผ ์ฑ„๋„๋กœ ๋ณ€ํ™˜ ํ›„ ์ „์ฒ˜๋ฆฌ
235
+ terrain_gray = self.convert_to_single_channel(terrain_image)
236
+ terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_gray.astype(np.uint8))).unsqueeze(0)
237
+
238
+ # Print shapes for debugging
239
+ print(f"RGB tensor shape: {rgb_tensor.shape}")
240
+ print(f"NDVI tensor shape: {ndvi_tensor.shape}")
241
+ print(f"Terrain tensor shape: {terrain_tensor.shape}")
242
+
243
+ # ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
244
+ elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
245
+ elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
246
+
247
+ # ๊ธฐ์ƒ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
248
+ weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
249
+ weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
250
+ weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
251
+
252
+ # ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
253
+ sample = {
254
+ 'rgb': rgb_tensor.to(self.device),
255
+ 'ndvi': ndvi_tensor.to(self.device),
256
+ 'terrain': terrain_tensor.to(self.device),
257
+ 'elevation': elevation_tensor.to(self.device),
258
+ 'weather_features': weather_features.to(self.device)
259
+ }
260
+
261
+ # ์˜ˆ์ธก
262
+ with torch.no_grad():
263
+ wind_pred, solar_pred = self.model(sample)
264
+
265
+ # ๊ฒฐ๊ณผ๋ฅผ numpy ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
266
+ wind_map = wind_pred.cpu().numpy()[0, 0]
267
+ solar_map = solar_pred.cpu().numpy()[0, 0]
268
+
269
+ # ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
270
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
271
+
272
+ # ํ’๋ ฅ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
273
+ sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
274
+ ax1.set_title('Wind Power Potential Map')
275
+
276
+ # ํƒœ์–‘๊ด‘ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
277
+ sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
278
+ ax2.set_title('Solar Power Potential Map')
279
+
280
+ plt.tight_layout()
281
+
282
+ return fig
283
+ except Exception as e:
284
+ print(f"Error in prediction: {str(e)}")
285
+ raise e
286
 
 
287
  def create_gradio_interface():
288
  predictor = ClimatePredictor('best_model.pth')
289
 
 
303
  fn=predict_and_visualize,
304
  inputs=[
305
  gr.Image(label="RGB Satellite Image", type="numpy"),
306
+ gr.Image(label="NDVI Image (will be converted to grayscale)", type="numpy"),
307
+ gr.Image(label="Terrain Map (will be converted to grayscale)", type="numpy"),
308
  gr.File(label="Elevation Data (NPY file)"),
309
  gr.Number(label="Wind Speed (m/s)", value=5.0),
310
  gr.Number(label="Wind Direction (degrees)", value=180.0),
 
313
  ],
314
  outputs=gr.Plot(label="Prediction Results"),
315
  title="Renewable Energy Potential Predictor",
316
+ description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
317
+ Note: NDVI and Terrain images will be automatically converted to grayscale.""",
318
  examples=[
319
  [
320
  "examples/rgb_example.png",
 
327
  )
328
  return interface
329
 
330
+ # Hugging Face Spaces์—์„œ ์•ฑ ์‹คํ–‰
331
  if __name__ == "__main__":
332
  interface = create_gradio_interface()
333
+ interface.launch()