HoeioUser commited on
Commit
6fa0406
ยท
verified ยท
1 Parent(s): 6be8089

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -26
app.py CHANGED
@@ -215,7 +215,6 @@ class ClimatePredictor:
215
  def convert_to_single_channel(self, image_array):
216
  """RGB ์ด๋ฏธ์ง€๋ฅผ ๋‹จ์ผ ์ฑ„๋„๋กœ ๋ณ€ํ™˜"""
217
  if len(image_array.shape) == 3:
218
- # RGB to grayscale conversion
219
  return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
220
  return image_array
221
 
@@ -235,11 +234,6 @@ class ClimatePredictor:
235
  terrain_gray = self.convert_to_single_channel(terrain_image)
236
  terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_gray.astype(np.uint8))).unsqueeze(0)
237
 
238
- # Print shapes for debugging
239
- print(f"RGB tensor shape: {rgb_tensor.shape}")
240
- print(f"NDVI tensor shape: {ndvi_tensor.shape}")
241
- print(f"Terrain tensor shape: {terrain_tensor.shape}")
242
-
243
  # ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
244
  elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
245
  elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
@@ -269,11 +263,9 @@ class ClimatePredictor:
269
  # ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
270
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
271
 
272
- # ํ’๋ ฅ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
273
  sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
274
  ax1.set_title('Wind Power Potential Map')
275
 
276
- # ํƒœ์–‘๊ด‘ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
277
  sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
278
  ax2.set_title('Solar Power Potential Map')
279
 
@@ -284,27 +276,72 @@ class ClimatePredictor:
284
  print(f"Error in prediction: {str(e)}")
285
  raise e
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def create_gradio_interface():
288
  predictor = ClimatePredictor('best_model.pth')
289
 
290
- def predict_and_visualize(rgb_image, ndvi_image, terrain_image, elevation_file,
291
  wind_speed, wind_direction, temperature, humidity):
292
- # Load elevation data
293
- elevation_data = np.load(elevation_file.name)
 
 
 
294
 
295
- # Generate prediction and visualization
296
  result = predictor.predict_from_inputs(
297
  rgb_image, ndvi_image, terrain_image, elevation_data,
298
  wind_speed, wind_direction, temperature, humidity
299
  )
300
  return result
301
 
 
 
 
 
302
  interface = gr.Interface(
303
  fn=predict_and_visualize,
304
  inputs=[
305
- gr.Image(label="RGB Satellite Image", type="numpy"),
306
- gr.Image(label="NDVI Image (will be converted to grayscale)", type="numpy"),
307
- gr.Image(label="Terrain Map (will be converted to grayscale)", type="numpy"),
308
  gr.File(label="Elevation Data (NPY file)"),
309
  gr.Number(label="Wind Speed (m/s)", value=5.0),
310
  gr.Number(label="Wind Direction (degrees)", value=180.0),
@@ -314,20 +351,12 @@ def create_gradio_interface():
314
  outputs=gr.Plot(label="Prediction Results"),
315
  title="Renewable Energy Potential Predictor",
316
  description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
317
- Note: NDVI and Terrain images will be automatically converted to grayscale.""",
318
- examples=[
319
- [
320
- "examples/rgb_example.png",
321
- "examples/ndvi_example.png",
322
- "examples/terrain_example.png",
323
- "examples/elevation_example.npy",
324
- 5.0, 180.0, 25.0, 60.0
325
- ]
326
- ]
327
  )
328
  return interface
329
 
330
- # Hugging Face Spaces์—์„œ ์•ฑ ์‹คํ–‰
331
  if __name__ == "__main__":
332
  interface = create_gradio_interface()
333
  interface.launch()
 
215
  def convert_to_single_channel(self, image_array):
216
  """RGB ์ด๋ฏธ์ง€๋ฅผ ๋‹จ์ผ ์ฑ„๋„๋กœ ๋ณ€ํ™˜"""
217
  if len(image_array.shape) == 3:
 
218
  return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
219
  return image_array
220
 
 
234
  terrain_gray = self.convert_to_single_channel(terrain_image)
235
  terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_gray.astype(np.uint8))).unsqueeze(0)
236
 
 
 
 
 
 
237
  # ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
238
  elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
239
  elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
 
263
  # ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
264
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
265
 
 
266
  sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
267
  ax1.set_title('Wind Power Potential Map')
268
 
 
269
  sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
270
  ax2.set_title('Solar Power Potential Map')
271
 
 
276
  print(f"Error in prediction: {str(e)}")
277
  raise e
278
 
279
+ def load_examples_from_directory(base_dir):
280
+ """ํด๋”์—์„œ ์˜ˆ์ œ ๋ฐ์ดํ„ฐ ๋กœ๋“œ"""
281
+ examples = []
282
+ sample_dirs = sorted(glob.glob(os.path.join(base_dir, "sample_*")))
283
+
284
+ for sample_dir in sample_dirs:
285
+ try:
286
+ # ํŒŒ์ผ ๊ฒฝ๋กœ ๊ตฌ์„ฑ
287
+ rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
288
+ ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
289
+ terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
290
+ elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
291
+ weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
292
+
293
+ # ๊ธฐ์ƒ ๋ฐ์ดํ„ฐ ์ฝ๊ธฐ
294
+ weather_data = pd.read_csv(weather_path)
295
+ wind_speed = weather_data['wind_speed'].mean()
296
+ wind_direction = weather_data['wind_direction'].mean()
297
+ temperature = weather_data['temperature'].mean()
298
+ humidity = weather_data['humidity'].mean()
299
+
300
+ # ์˜ˆ์ œ ๋ฆฌ์ŠคํŠธ์— ์ถ”๊ฐ€
301
+ examples.append([
302
+ rgb_path,
303
+ ndvi_path,
304
+ terrain_path,
305
+ elevation_path,
306
+ float(wind_speed),
307
+ float(wind_direction),
308
+ float(temperature),
309
+ float(humidity)
310
+ ])
311
+ except Exception as e:
312
+ print(f"Error loading example from {sample_dir}: {str(e)}")
313
+ continue
314
+
315
+ return examples
316
+
317
  def create_gradio_interface():
318
  predictor = ClimatePredictor('best_model.pth')
319
 
320
+ def predict_and_visualize(rgb_path, ndvi_path, terrain_path, elevation_path,
321
  wind_speed, wind_direction, temperature, humidity):
322
+ # ์ด๋ฏธ์ง€ ๋กœ๋“œ
323
+ rgb_image = np.array(Image.open(rgb_path))
324
+ ndvi_image = np.array(Image.open(ndvi_path))
325
+ terrain_image = np.array(Image.open(terrain_path))
326
+ elevation_data = np.load(elevation_path)
327
 
328
+ # ์˜ˆ์ธก ๋ฐ ์‹œ๊ฐํ™”
329
  result = predictor.predict_from_inputs(
330
  rgb_image, ndvi_image, terrain_image, elevation_data,
331
  wind_speed, wind_direction, temperature, humidity
332
  )
333
  return result
334
 
335
+ # ์˜ˆ์ œ ๋ฐ์ดํ„ฐ ๋กœ๋“œ
336
+ examples = load_examples_from_directory("filtered_climate_data")
337
+ print(f"Loaded {len(examples)} examples")
338
+
339
  interface = gr.Interface(
340
  fn=predict_and_visualize,
341
  inputs=[
342
+ gr.Image(label="RGB Satellite Image", type="filepath"),
343
+ gr.Image(label="NDVI Image", type="filepath"),
344
+ gr.Image(label="Terrain Map", type="filepath"),
345
  gr.File(label="Elevation Data (NPY file)"),
346
  gr.Number(label="Wind Speed (m/s)", value=5.0),
347
  gr.Number(label="Wind Direction (degrees)", value=180.0),
 
351
  outputs=gr.Plot(label="Prediction Results"),
352
  title="Renewable Energy Potential Predictor",
353
  description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
354
+ You can also try various examples from our dataset using the Examples section below.""",
355
+ examples=examples,
356
+ cache_examples=True
 
 
 
 
 
 
 
357
  )
358
  return interface
359
 
 
360
  if __name__ == "__main__":
361
  interface = create_gradio_interface()
362
  interface.launch()