Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
-
from PIL import Image
|
5 |
import torchvision.transforms as transforms
|
6 |
import pandas as pd
|
7 |
import os
|
@@ -11,11 +11,6 @@ import torch.nn as nn
|
|
11 |
import torch.nn.functional as F
|
12 |
from torch.utils.data import Dataset, DataLoader
|
13 |
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
14 |
-
import numpy as np
|
15 |
-
import pandas as pd
|
16 |
-
from PIL import Image
|
17 |
-
import os
|
18 |
-
import torchvision.transforms as transforms
|
19 |
from sklearn.model_selection import train_test_split
|
20 |
import glob
|
21 |
|
@@ -24,11 +19,8 @@ class ClimateNet(nn.Module):
|
|
24 |
super(ClimateNet, self).__init__()
|
25 |
self.input_size = input_size
|
26 |
self.output_size = output_size
|
27 |
-
|
28 |
-
# Feature map sizes after two max pooling layers
|
29 |
self.feature_size = (input_size[0] // 4, input_size[1] // 4)
|
30 |
|
31 |
-
# Improved RGB Encoder with residual connections
|
32 |
self.rgb_encoder = nn.Sequential(
|
33 |
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
34 |
nn.BatchNorm2d(64),
|
@@ -49,7 +41,6 @@ class ClimateNet(nn.Module):
|
|
49 |
nn.Dropout2d(0.2)
|
50 |
)
|
51 |
|
52 |
-
# Improved NDVI Encoder
|
53 |
self.ndvi_encoder = nn.Sequential(
|
54 |
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
55 |
nn.BatchNorm2d(64),
|
@@ -67,7 +58,6 @@ class ClimateNet(nn.Module):
|
|
67 |
nn.Dropout2d(0.2)
|
68 |
)
|
69 |
|
70 |
-
# Improved Terrain Encoder
|
71 |
self.terrain_encoder = nn.Sequential(
|
72 |
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
73 |
nn.BatchNorm2d(64),
|
@@ -85,7 +75,6 @@ class ClimateNet(nn.Module):
|
|
85 |
nn.Dropout2d(0.2)
|
86 |
)
|
87 |
|
88 |
-
# Improved Weather Encoder with deeper architecture
|
89 |
self.weather_encoder = nn.Sequential(
|
90 |
nn.Linear(4, 64),
|
91 |
nn.ReLU(),
|
@@ -96,7 +85,6 @@ class ClimateNet(nn.Module):
|
|
96 |
nn.Linear(128, 128)
|
97 |
)
|
98 |
|
99 |
-
# Improved Feature Fusion
|
100 |
self.fusion = nn.Sequential(
|
101 |
nn.Conv2d(512, 512, kernel_size=1),
|
102 |
nn.BatchNorm2d(512),
|
@@ -107,7 +95,6 @@ class ClimateNet(nn.Module):
|
|
107 |
nn.ReLU()
|
108 |
)
|
109 |
|
110 |
-
# Improved Decoders with skip connections
|
111 |
self.wind_decoder = nn.Sequential(
|
112 |
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
|
113 |
nn.BatchNorm2d(256),
|
@@ -143,18 +130,15 @@ class ClimateNet(nn.Module):
|
|
143 |
def forward(self, x):
|
144 |
batch_size = x['rgb'].size(0)
|
145 |
|
146 |
-
# Resize all inputs to input_size
|
147 |
rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
|
148 |
ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
|
149 |
terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
terrain_features = self.terrain_encoder(terrain_input) # [B, 128, H/4, W/4]
|
155 |
|
156 |
-
|
157 |
-
weather_features = self.weather_encoder(x['weather_features']) # [B, 128]
|
158 |
weather_features = weather_features.view(batch_size, 128, 1, 1)
|
159 |
weather_features = F.interpolate(
|
160 |
weather_features,
|
@@ -162,7 +146,6 @@ class ClimateNet(nn.Module):
|
|
162 |
mode='nearest'
|
163 |
)
|
164 |
|
165 |
-
# Combine features
|
166 |
combined_features = torch.cat([
|
167 |
rgb_features,
|
168 |
ndvi_features,
|
@@ -170,10 +153,8 @@ class ClimateNet(nn.Module):
|
|
170 |
weather_features
|
171 |
], dim=1)
|
172 |
|
173 |
-
# Apply fusion
|
174 |
fused_features = self.fusion(combined_features)
|
175 |
|
176 |
-
# Generate predictions and resize to output_size
|
177 |
wind_heatmap = self.wind_decoder(fused_features)
|
178 |
solar_heatmap = self.solar_decoder(fused_features)
|
179 |
|
@@ -191,7 +172,6 @@ class ClimatePredictor:
|
|
191 |
|
192 |
print(f"Using device: {self.device}")
|
193 |
|
194 |
-
# Load model
|
195 |
self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
|
196 |
checkpoint = torch.load(model_path, map_location=self.device)
|
197 |
|
@@ -214,15 +194,44 @@ class ClimatePredictor:
|
|
214 |
])
|
215 |
|
216 |
def convert_to_single_channel(self, image_array):
|
217 |
-
"""RGB 이미지를 단일 채널로 변환"""
|
218 |
if len(image_array.shape) == 3:
|
219 |
return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
|
220 |
return image_array
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
|
223 |
elevation_data, wind_speed, wind_direction,
|
224 |
temperature, humidity):
|
225 |
-
"""Gradio 인터페이스용 예측 함수"""
|
226 |
try:
|
227 |
# RGB 이미지 전처리
|
228 |
rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
|
@@ -262,17 +271,37 @@ class ClimatePredictor:
|
|
262 |
solar_map = solar_pred.cpu().numpy()[0, 0]
|
263 |
|
264 |
# 결과 시각화
|
265 |
-
fig
|
266 |
|
267 |
-
|
268 |
-
ax1.
|
|
|
|
|
|
|
|
|
269 |
|
270 |
-
|
271 |
-
ax2.
|
|
|
|
|
|
|
272 |
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
|
|
275 |
return fig
|
|
|
276 |
except Exception as e:
|
277 |
print(f"Error in prediction: {str(e)}")
|
278 |
raise e
|
@@ -284,21 +313,18 @@ def load_examples_from_directory(base_dir):
|
|
284 |
|
285 |
for sample_dir in sample_dirs:
|
286 |
try:
|
287 |
-
# 파일 경로 구성
|
288 |
rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
|
289 |
ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
|
290 |
terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
|
291 |
elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
|
292 |
weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
|
293 |
|
294 |
-
# 기상 데이터 읽기
|
295 |
weather_data = pd.read_csv(weather_path)
|
296 |
wind_speed = weather_data['wind_speed'].mean()
|
297 |
wind_direction = weather_data['wind_direction'].mean()
|
298 |
temperature = weather_data['temperature'].mean()
|
299 |
humidity = weather_data['humidity'].mean()
|
300 |
|
301 |
-
# 예제 리스트에 추가
|
302 |
examples.append([
|
303 |
rgb_path,
|
304 |
ndvi_path,
|
@@ -318,46 +344,52 @@ def load_examples_from_directory(base_dir):
|
|
318 |
def create_gradio_interface():
|
319 |
predictor = ClimatePredictor('best_model.pth')
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
rgb_image = np.array(Image.open(rgb_path))
|
325 |
-
ndvi_image = np.array(Image.open(ndvi_path))
|
326 |
-
terrain_image = np.array(Image.open(terrain_path))
|
327 |
-
elevation_data = np.load(elevation_path)
|
328 |
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
)
|
334 |
-
return result
|
335 |
-
|
336 |
-
# 예제 데이터 로드
|
337 |
-
examples = load_examples_from_directory("filtered_climate_data")
|
338 |
-
print(f"Loaded {len(examples)} examples")
|
339 |
|
340 |
-
interface = gr.Interface(
|
341 |
-
fn=predict_and_visualize,
|
342 |
-
inputs=[
|
343 |
-
gr.Image(label="RGB Satellite Image", type="filepath"),
|
344 |
-
gr.Image(label="NDVI Image", type="filepath"),
|
345 |
-
gr.Image(label="Terrain Map", type="filepath"),
|
346 |
-
gr.File(label="Elevation Data (NPY file)"),
|
347 |
-
gr.Number(label="Wind Speed (m/s)", value=5.0),
|
348 |
-
gr.Number(label="Wind Direction (degrees)", value=180.0),
|
349 |
-
gr.Number(label="Temperature (°C)", value=25.0),
|
350 |
-
gr.Number(label="Humidity (%)", value=60.0)
|
351 |
-
],
|
352 |
-
outputs=gr.Plot(label="Prediction Results"),
|
353 |
-
title="Renewable Energy Potential Predictor",
|
354 |
-
description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
|
355 |
-
You can also try various examples from our dataset using the Examples section below.""",
|
356 |
-
examples=examples,
|
357 |
-
cache_examples=True
|
358 |
-
)
|
359 |
return interface
|
360 |
|
361 |
if __name__ == "__main__":
|
362 |
interface = create_gradio_interface()
|
363 |
-
interface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import numpy as np
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
import torchvision.transforms as transforms
|
6 |
import pandas as pd
|
7 |
import os
|
|
|
11 |
import torch.nn.functional as F
|
12 |
from torch.utils.data import Dataset, DataLoader
|
13 |
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
|
|
|
|
|
|
|
14 |
from sklearn.model_selection import train_test_split
|
15 |
import glob
|
16 |
|
|
|
19 |
super(ClimateNet, self).__init__()
|
20 |
self.input_size = input_size
|
21 |
self.output_size = output_size
|
|
|
|
|
22 |
self.feature_size = (input_size[0] // 4, input_size[1] // 4)
|
23 |
|
|
|
24 |
self.rgb_encoder = nn.Sequential(
|
25 |
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
26 |
nn.BatchNorm2d(64),
|
|
|
41 |
nn.Dropout2d(0.2)
|
42 |
)
|
43 |
|
|
|
44 |
self.ndvi_encoder = nn.Sequential(
|
45 |
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
46 |
nn.BatchNorm2d(64),
|
|
|
58 |
nn.Dropout2d(0.2)
|
59 |
)
|
60 |
|
|
|
61 |
self.terrain_encoder = nn.Sequential(
|
62 |
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
63 |
nn.BatchNorm2d(64),
|
|
|
75 |
nn.Dropout2d(0.2)
|
76 |
)
|
77 |
|
|
|
78 |
self.weather_encoder = nn.Sequential(
|
79 |
nn.Linear(4, 64),
|
80 |
nn.ReLU(),
|
|
|
85 |
nn.Linear(128, 128)
|
86 |
)
|
87 |
|
|
|
88 |
self.fusion = nn.Sequential(
|
89 |
nn.Conv2d(512, 512, kernel_size=1),
|
90 |
nn.BatchNorm2d(512),
|
|
|
95 |
nn.ReLU()
|
96 |
)
|
97 |
|
|
|
98 |
self.wind_decoder = nn.Sequential(
|
99 |
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
|
100 |
nn.BatchNorm2d(256),
|
|
|
130 |
def forward(self, x):
|
131 |
batch_size = x['rgb'].size(0)
|
132 |
|
|
|
133 |
rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
|
134 |
ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
|
135 |
terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
|
136 |
|
137 |
+
rgb_features = self.rgb_encoder(rgb_input)
|
138 |
+
ndvi_features = self.ndvi_encoder(ndvi_input)
|
139 |
+
terrain_features = self.terrain_encoder(terrain_input)
|
|
|
140 |
|
141 |
+
weather_features = self.weather_encoder(x['weather_features'])
|
|
|
142 |
weather_features = weather_features.view(batch_size, 128, 1, 1)
|
143 |
weather_features = F.interpolate(
|
144 |
weather_features,
|
|
|
146 |
mode='nearest'
|
147 |
)
|
148 |
|
|
|
149 |
combined_features = torch.cat([
|
150 |
rgb_features,
|
151 |
ndvi_features,
|
|
|
153 |
weather_features
|
154 |
], dim=1)
|
155 |
|
|
|
156 |
fused_features = self.fusion(combined_features)
|
157 |
|
|
|
158 |
wind_heatmap = self.wind_decoder(fused_features)
|
159 |
solar_heatmap = self.solar_decoder(fused_features)
|
160 |
|
|
|
172 |
|
173 |
print(f"Using device: {self.device}")
|
174 |
|
|
|
175 |
self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
|
176 |
checkpoint = torch.load(model_path, map_location=self.device)
|
177 |
|
|
|
194 |
])
|
195 |
|
196 |
def convert_to_single_channel(self, image_array):
|
|
|
197 |
if len(image_array.shape) == 3:
|
198 |
return np.dot(image_array[...,:3], [0.2989, 0.5870, 0.1140])
|
199 |
return image_array
|
200 |
|
201 |
+
def overlay_emojis(self, rgb_image, wind_map, solar_map, threshold=0.7):
|
202 |
+
"""히트맵을 기반으로 이모지 오버레이"""
|
203 |
+
img = Image.fromarray(rgb_image)
|
204 |
+
draw = ImageDraw.Draw(img)
|
205 |
+
|
206 |
+
h, w = rgb_image.shape[:2]
|
207 |
+
wind_map_resized = Image.fromarray((wind_map * 255).astype(np.uint8)).resize((w, h))
|
208 |
+
solar_map_resized = Image.fromarray((solar_map * 255).astype(np.uint8)).resize((w, h))
|
209 |
+
|
210 |
+
wind_map_np = np.array(wind_map_resized) / 255.0
|
211 |
+
solar_map_np = np.array(solar_map_resized) / 255.0
|
212 |
+
|
213 |
+
emoji_size = min(w, h) // 20
|
214 |
+
grid_step = emoji_size * 2
|
215 |
+
|
216 |
+
for y in range(0, h - emoji_size, grid_step):
|
217 |
+
for x in range(0, w - emoji_size, grid_step):
|
218 |
+
wind_val = wind_map_np[y:y+emoji_size, x:x+emoji_size].mean()
|
219 |
+
solar_val = solar_map_np[y:y+emoji_size, x:x+emoji_size].mean()
|
220 |
+
|
221 |
+
text = ""
|
222 |
+
if wind_val > threshold:
|
223 |
+
text += "💨"
|
224 |
+
if solar_val > threshold:
|
225 |
+
text += "☀️"
|
226 |
+
|
227 |
+
if text:
|
228 |
+
draw.text((x, y), text, fill="white")
|
229 |
+
|
230 |
+
return np.array(img)
|
231 |
+
|
232 |
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
|
233 |
elevation_data, wind_speed, wind_direction,
|
234 |
temperature, humidity):
|
|
|
235 |
try:
|
236 |
# RGB 이미지 전처리
|
237 |
rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
|
|
|
271 |
solar_map = solar_pred.cpu().numpy()[0, 0]
|
272 |
|
273 |
# 결과 시각화
|
274 |
+
fig = plt.figure(figsize=(20, 10))
|
275 |
|
276 |
+
# 1. 원본 이미지와 이모지 오버레이
|
277 |
+
ax1 = plt.subplot(2, 2, 1)
|
278 |
+
overlay_img = self.overlay_emojis(rgb_image, wind_map, solar_map)
|
279 |
+
ax1.imshow(overlay_img)
|
280 |
+
ax1.set_title('Predictions Overlay on RGB Image')
|
281 |
+
ax1.axis('off')
|
282 |
|
283 |
+
# 2. 풍력 발전 잠재량 히트맵
|
284 |
+
ax2 = plt.subplot(2, 2, 2)
|
285 |
+
wind_heatmap = sns.heatmap(wind_map, ax=ax2, cmap='YlOrRd',
|
286 |
+
cbar_kws={'label': 'Wind Power Potential'})
|
287 |
+
ax2.set_title('Wind Power Potential Map')
|
288 |
|
289 |
+
# 3. 태양광 발전 잠재량 히트맵
|
290 |
+
ax3 = plt.subplot(2, 2, 3)
|
291 |
+
solar_heatmap = sns.heatmap(solar_map, ax=ax3, cmap='YlOrRd',
|
292 |
+
cbar_kws={'label': 'Solar Power Potential'})
|
293 |
+
ax3.set_title('Solar Power Potential Map')
|
294 |
+
|
295 |
+
# 4. 컴바인드 히트맵
|
296 |
+
ax4 = plt.subplot(2, 2, 4)
|
297 |
+
combined_map = np.stack([solar_map, wind_map, np.zeros_like(wind_map)], axis=-1)
|
298 |
+
ax4.imshow(combined_map)
|
299 |
+
ax4.set_title('Combined Potential Map (Red: Solar, Green: Wind)')
|
300 |
+
ax4.axis('off')
|
301 |
|
302 |
+
plt.tight_layout()
|
303 |
return fig
|
304 |
+
|
305 |
except Exception as e:
|
306 |
print(f"Error in prediction: {str(e)}")
|
307 |
raise e
|
|
|
313 |
|
314 |
for sample_dir in sample_dirs:
|
315 |
try:
|
|
|
316 |
rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
|
317 |
ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
|
318 |
terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
|
319 |
elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
|
320 |
weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
|
321 |
|
|
|
322 |
weather_data = pd.read_csv(weather_path)
|
323 |
wind_speed = weather_data['wind_speed'].mean()
|
324 |
wind_direction = weather_data['wind_direction'].mean()
|
325 |
temperature = weather_data['temperature'].mean()
|
326 |
humidity = weather_data['humidity'].mean()
|
327 |
|
|
|
328 |
examples.append([
|
329 |
rgb_path,
|
330 |
ndvi_path,
|
|
|
344 |
def create_gradio_interface():
|
345 |
predictor = ClimatePredictor('best_model.pth')
|
346 |
|
347 |
+
with gr.Blocks() as interface:
|
348 |
+
gr.Markdown("# Renewable Energy Potential Predictor")
|
349 |
+
gr.Markdown("Upload satellite imagery and environmental data to predict wind and solar power potential.")
|
|
|
|
|
|
|
|
|
350 |
|
351 |
+
with gr.Row():
|
352 |
+
# 입력 섹션 (1/3 크기)
|
353 |
+
with gr.Column(scale=1):
|
354 |
+
rgb_input = gr.Image(label="RGB Satellite Image", type="numpy")
|
355 |
+
ndvi_input = gr.Image(label="NDVI Image", type="numpy")
|
356 |
+
terrain_input = gr.Image(label="Terrain Map", type="numpy")
|
357 |
+
elevation_input = gr.File(label="Elevation Data (NPY file)")
|
358 |
+
|
359 |
+
with gr.Row():
|
360 |
+
wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
|
361 |
+
wind_direction = gr.Number(label="Wind Direction (°)", value=180.0)
|
362 |
+
|
363 |
+
with gr.Row():
|
364 |
+
temperature = gr.Number(label="Temperature (°C)", value=25.0)
|
365 |
+
humidity = gr.Number(label="Humidity (%)", value=60.0)
|
366 |
+
|
367 |
+
predict_btn = gr.Button("Generate Predictions", variant="primary")
|
368 |
+
|
369 |
+
# 출력 섹션 (2/3 크기)
|
370 |
+
with gr.Column(scale=2):
|
371 |
+
output_plot = gr.Plot(label="Prediction Results")
|
372 |
+
|
373 |
+
# 예측 버튼 클릭 이벤트 연결
|
374 |
+
predict_btn.click(
|
375 |
+
fn=predictor.predict_from_inputs,
|
376 |
+
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
377 |
+
wind_speed, wind_direction, temperature, humidity],
|
378 |
+
outputs=output_plot
|
379 |
+
)
|
380 |
+
|
381 |
+
# 예제 섹션
|
382 |
+
examples = load_examples_from_directory("filtered_climate_data")
|
383 |
+
gr.Examples(
|
384 |
+
examples=examples,
|
385 |
+
inputs=[rgb_input, ndvi_input, terrain_input, elevation_input,
|
386 |
+
wind_speed, wind_direction, temperature, humidity],
|
387 |
+
outputs=output_plot,
|
388 |
+
cache_examples=True
|
389 |
)
|
|
|
|
|
|
|
|
|
|
|
390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
return interface
|
392 |
|
393 |
if __name__ == "__main__":
|
394 |
interface = create_gradio_interface()
|
395 |
+
interface.launch()
|