HoeioUser commited on
Commit
11aa294
·
verified ·
1 Parent(s): 7a401a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -72
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- from PIL import Image
5
  import torchvision.transforms as transforms
6
  import pandas as pd
7
  import os
@@ -11,11 +11,6 @@ import torch.nn as nn
11
  import torch.nn.functional as F
12
  from torch.utils.data import Dataset, DataLoader
13
  from torch.optim.lr_scheduler import ReduceLROnPlateau
14
- import numpy as np
15
- import pandas as pd
16
- from PIL import Image
17
- import os
18
- import torchvision.transforms as transforms
19
  from sklearn.model_selection import train_test_split
20
  import glob
21
 
@@ -24,11 +19,8 @@ class ClimateNet(nn.Module):
24
  super(ClimateNet, self).__init__()
25
  self.input_size = input_size
26
  self.output_size = output_size
27
-
28
- # Feature map sizes after two max pooling layers
29
  self.feature_size = (input_size[0] // 4, input_size[1] // 4)
30
 
31
- # Improved RGB Encoder with residual connections
32
  self.rgb_encoder = nn.Sequential(
33
  nn.Conv2d(3, 64, kernel_size=3, padding=1),
34
  nn.BatchNorm2d(64),
@@ -49,7 +41,6 @@ class ClimateNet(nn.Module):
49
  nn.Dropout2d(0.2)
50
  )
51
 
52
- # Improved NDVI Encoder
53
  self.ndvi_encoder = nn.Sequential(
54
  nn.Conv2d(1, 64, kernel_size=3, padding=1),
55
  nn.BatchNorm2d(64),
@@ -67,7 +58,6 @@ class ClimateNet(nn.Module):
67
  nn.Dropout2d(0.2)
68
  )
69
 
70
- # Improved Terrain Encoder
71
  self.terrain_encoder = nn.Sequential(
72
  nn.Conv2d(1, 64, kernel_size=3, padding=1),
73
  nn.BatchNorm2d(64),
@@ -85,7 +75,6 @@ class ClimateNet(nn.Module):
85
  nn.Dropout2d(0.2)
86
  )
87
 
88
- # Improved Weather Encoder with deeper architecture
89
  self.weather_encoder = nn.Sequential(
90
  nn.Linear(4, 64),
91
  nn.ReLU(),
@@ -96,7 +85,6 @@ class ClimateNet(nn.Module):
96
  nn.Linear(128, 128)
97
  )
98
 
99
- # Improved Feature Fusion
100
  self.fusion = nn.Sequential(
101
  nn.Conv2d(512, 512, kernel_size=1),
102
  nn.BatchNorm2d(512),
@@ -107,7 +95,6 @@ class ClimateNet(nn.Module):
107
  nn.ReLU()
108
  )
109
 
110
- # Improved Decoders with skip connections
111
  self.wind_decoder = nn.Sequential(
112
  nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
113
  nn.BatchNorm2d(256),
@@ -143,18 +130,15 @@ class ClimateNet(nn.Module):
143
  def forward(self, x):
144
  batch_size = x['rgb'].size(0)
145
 
146
- # Resize all inputs to input_size
147
  rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
148
  ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
149
  terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
150
 
151
- # Extract features
152
- rgb_features = self.rgb_encoder(rgb_input) # [B, 128, H/4, W/4]
153
- ndvi_features = self.ndvi_encoder(ndvi_input) # [B, 128, H/4, W/4]
154
- terrain_features = self.terrain_encoder(terrain_input) # [B, 128, H/4, W/4]
155
 
156
- # Process weather features and expand to match feature map size
157
- weather_features = self.weather_encoder(x['weather_features']) # [B, 128]
158
  weather_features = weather_features.view(batch_size, 128, 1, 1)
159
  weather_features = F.interpolate(
160
  weather_features,
@@ -162,7 +146,6 @@ class ClimateNet(nn.Module):
162
  mode='nearest'
163
  )
164
 
165
- # Combine features
166
  combined_features = torch.cat([
167
  rgb_features,
168
  ndvi_features,
@@ -170,10 +153,8 @@ class ClimateNet(nn.Module):
170
  weather_features
171
  ], dim=1)
172
 
173
- # Apply fusion
174
  fused_features = self.fusion(combined_features)
175
 
176
- # Generate predictions and resize to output_size
177
  wind_heatmap = self.wind_decoder(fused_features)
178
  solar_heatmap = self.solar_decoder(fused_features)
179
 
@@ -191,7 +172,6 @@ class ClimatePredictor:
191
 
192
  print(f"Using device: {self.device}")
193
 
194
- # Load model
195
  self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
196
  checkpoint = torch.load(model_path, map_location=self.device)
197
 
@@ -214,15 +194,44 @@ class ClimatePredictor:
214
  ])
215
 
216
  def convert_to_single_channel(self, image_array):
217
- """RGB 이미지를 단일 채널로 변환"""
218
  if len(image_array.shape) == 3:
219
  return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
220
  return image_array
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
223
  elevation_data, wind_speed, wind_direction,
224
  temperature, humidity):
225
- """Gradio 인터페이스용 예측 함수"""
226
  try:
227
  # RGB 이미지 전처리
228
  rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
@@ -262,17 +271,37 @@ class ClimatePredictor:
262
  solar_map = solar_pred.cpu().numpy()[0, 0]
263
 
264
  # 결과 시각화
265
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
266
 
267
- sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
268
- ax1.set_title('Wind Power Potential Map')
 
 
 
 
269
 
270
- sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
271
- ax2.set_title('Solar Power Potential Map')
 
 
 
272
 
273
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
274
 
 
275
  return fig
 
276
  except Exception as e:
277
  print(f"Error in prediction: {str(e)}")
278
  raise e
@@ -284,21 +313,18 @@ def load_examples_from_directory(base_dir):
284
 
285
  for sample_dir in sample_dirs:
286
  try:
287
- # 파일 경로 구성
288
  rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
289
  ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
290
  terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
291
  elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
292
  weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
293
 
294
- # 기상 데이터 읽기
295
  weather_data = pd.read_csv(weather_path)
296
  wind_speed = weather_data['wind_speed'].mean()
297
  wind_direction = weather_data['wind_direction'].mean()
298
  temperature = weather_data['temperature'].mean()
299
  humidity = weather_data['humidity'].mean()
300
 
301
- # 예제 리스트에 추가
302
  examples.append([
303
  rgb_path,
304
  ndvi_path,
@@ -318,46 +344,52 @@ def load_examples_from_directory(base_dir):
318
  def create_gradio_interface():
319
  predictor = ClimatePredictor('best_model.pth')
320
 
321
- def predict_and_visualize(rgb_path, ndvi_path, terrain_path, elevation_path,
322
- wind_speed, wind_direction, temperature, humidity):
323
- # 이미지 로드
324
- rgb_image = np.array(Image.open(rgb_path))
325
- ndvi_image = np.array(Image.open(ndvi_path))
326
- terrain_image = np.array(Image.open(terrain_path))
327
- elevation_data = np.load(elevation_path)
328
 
329
- # 예측 및 시각화
330
- result = predictor.predict_from_inputs(
331
- rgb_image, ndvi_image, terrain_image, elevation_data,
332
- wind_speed, wind_direction, temperature, humidity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  )
334
- return result
335
-
336
- # 예제 데이터 로드
337
- examples = load_examples_from_directory("filtered_climate_data")
338
- print(f"Loaded {len(examples)} examples")
339
 
340
- interface = gr.Interface(
341
- fn=predict_and_visualize,
342
- inputs=[
343
- gr.Image(label="RGB Satellite Image", type="filepath"),
344
- gr.Image(label="NDVI Image", type="filepath"),
345
- gr.Image(label="Terrain Map", type="filepath"),
346
- gr.File(label="Elevation Data (NPY file)"),
347
- gr.Number(label="Wind Speed (m/s)", value=5.0),
348
- gr.Number(label="Wind Direction (degrees)", value=180.0),
349
- gr.Number(label="Temperature (°C)", value=25.0),
350
- gr.Number(label="Humidity (%)", value=60.0)
351
- ],
352
- outputs=gr.Plot(label="Prediction Results"),
353
- title="Renewable Energy Potential Predictor",
354
- description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
355
- You can also try various examples from our dataset using the Examples section below.""",
356
- examples=examples,
357
- cache_examples=True
358
- )
359
  return interface
360
 
361
  if __name__ == "__main__":
362
  interface = create_gradio_interface()
363
- interface.launch()
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
+ from PIL import Image, ImageDraw
5
  import torchvision.transforms as transforms
6
  import pandas as pd
7
  import os
 
11
  import torch.nn.functional as F
12
  from torch.utils.data import Dataset, DataLoader
13
  from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 
 
 
 
14
  from sklearn.model_selection import train_test_split
15
  import glob
16
 
 
19
  super(ClimateNet, self).__init__()
20
  self.input_size = input_size
21
  self.output_size = output_size
 
 
22
  self.feature_size = (input_size[0] // 4, input_size[1] // 4)
23
 
 
24
  self.rgb_encoder = nn.Sequential(
25
  nn.Conv2d(3, 64, kernel_size=3, padding=1),
26
  nn.BatchNorm2d(64),
 
41
  nn.Dropout2d(0.2)
42
  )
43
 
 
44
  self.ndvi_encoder = nn.Sequential(
45
  nn.Conv2d(1, 64, kernel_size=3, padding=1),
46
  nn.BatchNorm2d(64),
 
58
  nn.Dropout2d(0.2)
59
  )
60
 
 
61
  self.terrain_encoder = nn.Sequential(
62
  nn.Conv2d(1, 64, kernel_size=3, padding=1),
63
  nn.BatchNorm2d(64),
 
75
  nn.Dropout2d(0.2)
76
  )
77
 
 
78
  self.weather_encoder = nn.Sequential(
79
  nn.Linear(4, 64),
80
  nn.ReLU(),
 
85
  nn.Linear(128, 128)
86
  )
87
 
 
88
  self.fusion = nn.Sequential(
89
  nn.Conv2d(512, 512, kernel_size=1),
90
  nn.BatchNorm2d(512),
 
95
  nn.ReLU()
96
  )
97
 
 
98
  self.wind_decoder = nn.Sequential(
99
  nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
100
  nn.BatchNorm2d(256),
 
130
  def forward(self, x):
131
  batch_size = x['rgb'].size(0)
132
 
 
133
  rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
134
  ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
135
  terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
136
 
137
+ rgb_features = self.rgb_encoder(rgb_input)
138
+ ndvi_features = self.ndvi_encoder(ndvi_input)
139
+ terrain_features = self.terrain_encoder(terrain_input)
 
140
 
141
+ weather_features = self.weather_encoder(x['weather_features'])
 
142
  weather_features = weather_features.view(batch_size, 128, 1, 1)
143
  weather_features = F.interpolate(
144
  weather_features,
 
146
  mode='nearest'
147
  )
148
 
 
149
  combined_features = torch.cat([
150
  rgb_features,
151
  ndvi_features,
 
153
  weather_features
154
  ], dim=1)
155
 
 
156
  fused_features = self.fusion(combined_features)
157
 
 
158
  wind_heatmap = self.wind_decoder(fused_features)
159
  solar_heatmap = self.solar_decoder(fused_features)
160
 
 
172
 
173
  print(f"Using device: {self.device}")
174
 
 
175
  self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
176
  checkpoint = torch.load(model_path, map_location=self.device)
177
 
 
194
  ])
195
 
196
  def convert_to_single_channel(self, image_array):
 
197
  if len(image_array.shape) == 3:
198
  return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
199
  return image_array
200
 
201
+ def overlay_emojis(self, rgb_image, wind_map, solar_map, threshold=0.7):
202
+ """히트맵을 기반으로 이모지 오버레이"""
203
+ img = Image.fromarray(rgb_image)
204
+ draw = ImageDraw.Draw(img)
205
+
206
+ h, w = rgb_image.shape[:2]
207
+ wind_map_resized = Image.fromarray((wind_map * 255).astype(np.uint8)).resize((w, h))
208
+ solar_map_resized = Image.fromarray((solar_map * 255).astype(np.uint8)).resize((w, h))
209
+
210
+ wind_map_np = np.array(wind_map_resized) / 255.0
211
+ solar_map_np = np.array(solar_map_resized) / 255.0
212
+
213
+ emoji_size = min(w, h) // 20
214
+ grid_step = emoji_size * 2
215
+
216
+ for y in range(0, h - emoji_size, grid_step):
217
+ for x in range(0, w - emoji_size, grid_step):
218
+ wind_val = wind_map_np[y:y+emoji_size, x:x+emoji_size].mean()
219
+ solar_val = solar_map_np[y:y+emoji_size, x:x+emoji_size].mean()
220
+
221
+ text = ""
222
+ if wind_val > threshold:
223
+ text += "💨"
224
+ if solar_val > threshold:
225
+ text += "☀️"
226
+
227
+ if text:
228
+ draw.text((x, y), text, fill="white")
229
+
230
+ return np.array(img)
231
+
232
  def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
233
  elevation_data, wind_speed, wind_direction,
234
  temperature, humidity):
 
235
  try:
236
  # RGB 이미지 전처리
237
  rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
 
271
  solar_map = solar_pred.cpu().numpy()[0, 0]
272
 
273
  # 결과 시각화
274
+ fig = plt.figure(figsize=(20, 10))
275
 
276
+ # 1. 원본 이미지와 이모지 오버레이
277
+ ax1 = plt.subplot(2, 2, 1)
278
+ overlay_img = self.overlay_emojis(rgb_image, wind_map, solar_map)
279
+ ax1.imshow(overlay_img)
280
+ ax1.set_title('Predictions Overlay on RGB Image')
281
+ ax1.axis('off')
282
 
283
+ # 2. 풍력 발전 잠재량 히트맵
284
+ ax2 = plt.subplot(2, 2, 2)
285
+ wind_heatmap = sns.heatmap(wind_map, ax=ax2, cmap='YlOrRd',
286
+ cbar_kws={'label': 'Wind Power Potential'})
287
+ ax2.set_title('Wind Power Potential Map')
288
 
289
+ # 3. 태양광 발전 잠재량 히트맵
290
+ ax3 = plt.subplot(2, 2, 3)
291
+ solar_heatmap = sns.heatmap(solar_map, ax=ax3, cmap='YlOrRd',
292
+ cbar_kws={'label': 'Solar Power Potential'})
293
+ ax3.set_title('Solar Power Potential Map')
294
+
295
+ # 4. 컴바인드 히트맵
296
+ ax4 = plt.subplot(2, 2, 4)
297
+ combined_map = np.stack([solar_map, wind_map, np.zeros_like(wind_map)], axis=-1)
298
+ ax4.imshow(combined_map)
299
+ ax4.set_title('Combined Potential Map (Red: Solar, Green: Wind)')
300
+ ax4.axis('off')
301
 
302
+ plt.tight_layout()
303
  return fig
304
+
305
  except Exception as e:
306
  print(f"Error in prediction: {str(e)}")
307
  raise e
 
313
 
314
  for sample_dir in sample_dirs:
315
  try:
 
316
  rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
317
  ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
318
  terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
319
  elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
320
  weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
321
 
 
322
  weather_data = pd.read_csv(weather_path)
323
  wind_speed = weather_data['wind_speed'].mean()
324
  wind_direction = weather_data['wind_direction'].mean()
325
  temperature = weather_data['temperature'].mean()
326
  humidity = weather_data['humidity'].mean()
327
 
 
328
  examples.append([
329
  rgb_path,
330
  ndvi_path,
 
344
  def create_gradio_interface():
345
  predictor = ClimatePredictor('best_model.pth')
346
 
347
+ with gr.Blocks() as interface:
348
+ gr.Markdown("# Renewable Energy Potential Predictor")
349
+ gr.Markdown("Upload satellite imagery and environmental data to predict wind and solar power potential.")
 
 
 
 
350
 
351
+ with gr.Row():
352
+ # 입력 섹션 (1/3 크기)
353
+ with gr.Column(scale=1):
354
+ rgb_input = gr.Image(label="RGB Satellite Image", type="numpy")
355
+ ndvi_input = gr.Image(label="NDVI Image", type="numpy")
356
+ terrain_input = gr.Image(label="Terrain Map", type="numpy")
357
+ elevation_input = gr.File(label="Elevation Data (NPY file)")
358
+
359
+ with gr.Row():
360
+ wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
361
+ wind_direction = gr.Number(label="Wind Direction (°)", value=180.0)
362
+
363
+ with gr.Row():
364
+ temperature = gr.Number(label="Temperature (°C)", value=25.0)
365
+ humidity = gr.Number(label="Humidity (%)", value=60.0)
366
+
367
+ predict_btn = gr.Button("Generate Predictions", variant="primary")
368
+
369
+ # 출력 섹션 (2/3 크기)
370
+ with gr.Column(scale=2):
371
+ output_plot = gr.Plot(label="Prediction Results")
372
+
373
+ # 예측 버튼 클릭 이벤트 연결
374
+ predict_btn.click(
375
+ fn=predictor.predict_from_inputs,
376
+ inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
377
+ wind_speed, wind_direction, temperature, humidity],
378
+ outputs=output_plot
379
+ )
380
+
381
+ # 예제 섹션
382
+ examples = load_examples_from_directory("filtered_climate_data")
383
+ gr.Examples(
384
+ examples=examples,
385
+ inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
386
+ wind_speed, wind_direction, temperature, humidity],
387
+ outputs=output_plot,
388
+ cache_examples=True
389
  )
 
 
 
 
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  return interface
392
 
393
  if __name__ == "__main__":
394
  interface = create_gradio_interface()
395
+ interface.launch()