Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
import pandas as pd
|
7 |
+
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import seaborn as sns
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch.utils.data import Dataset, DataLoader
|
13 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from PIL import Image
|
17 |
+
import os
|
18 |
+
import torchvision.transforms as transforms
|
19 |
+
from sklearn.model_selection import train_test_split
|
20 |
+
|
21 |
+
class ClimateNet(nn.Module):
|
22 |
+
def __init__(self, input_size=(256, 256), output_size=(64, 64)):
|
23 |
+
super(ClimateNet, self).__init__()
|
24 |
+
self.input_size = input_size
|
25 |
+
self.output_size = output_size
|
26 |
+
|
27 |
+
# Feature map sizes after two max pooling layers
|
28 |
+
self.feature_size = (input_size[0] // 4, input_size[1] // 4)
|
29 |
+
|
30 |
+
# Improved RGB Encoder with residual connections
|
31 |
+
self.rgb_encoder = nn.Sequential(
|
32 |
+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
33 |
+
nn.BatchNorm2d(64),
|
34 |
+
nn.ReLU(),
|
35 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
36 |
+
nn.BatchNorm2d(64),
|
37 |
+
nn.ReLU(),
|
38 |
+
nn.MaxPool2d(2),
|
39 |
+
nn.Dropout2d(0.2),
|
40 |
+
|
41 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
42 |
+
nn.BatchNorm2d(128),
|
43 |
+
nn.ReLU(),
|
44 |
+
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
45 |
+
nn.BatchNorm2d(128),
|
46 |
+
nn.ReLU(),
|
47 |
+
nn.MaxPool2d(2),
|
48 |
+
nn.Dropout2d(0.2)
|
49 |
+
)
|
50 |
+
|
51 |
+
# Improved NDVI Encoder
|
52 |
+
self.ndvi_encoder = nn.Sequential(
|
53 |
+
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
54 |
+
nn.BatchNorm2d(64),
|
55 |
+
nn.ReLU(),
|
56 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
57 |
+
nn.BatchNorm2d(64),
|
58 |
+
nn.ReLU(),
|
59 |
+
nn.MaxPool2d(2),
|
60 |
+
nn.Dropout2d(0.2),
|
61 |
+
|
62 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
63 |
+
nn.BatchNorm2d(128),
|
64 |
+
nn.ReLU(),
|
65 |
+
nn.MaxPool2d(2),
|
66 |
+
nn.Dropout2d(0.2)
|
67 |
+
)
|
68 |
+
|
69 |
+
# Improved Terrain Encoder
|
70 |
+
self.terrain_encoder = nn.Sequential(
|
71 |
+
nn.Conv2d(1, 64, kernel_size=3, padding=1),
|
72 |
+
nn.BatchNorm2d(64),
|
73 |
+
nn.ReLU(),
|
74 |
+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
75 |
+
nn.BatchNorm2d(64),
|
76 |
+
nn.ReLU(),
|
77 |
+
nn.MaxPool2d(2),
|
78 |
+
nn.Dropout2d(0.2),
|
79 |
+
|
80 |
+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
81 |
+
nn.BatchNorm2d(128),
|
82 |
+
nn.ReLU(),
|
83 |
+
nn.MaxPool2d(2),
|
84 |
+
nn.Dropout2d(0.2)
|
85 |
+
)
|
86 |
+
|
87 |
+
# Improved Weather Encoder with deeper architecture
|
88 |
+
self.weather_encoder = nn.Sequential(
|
89 |
+
nn.Linear(4, 64),
|
90 |
+
nn.ReLU(),
|
91 |
+
nn.Dropout(0.2),
|
92 |
+
nn.Linear(64, 128),
|
93 |
+
nn.ReLU(),
|
94 |
+
nn.Dropout(0.2),
|
95 |
+
nn.Linear(128, 128)
|
96 |
+
)
|
97 |
+
|
98 |
+
# Improved Feature Fusion
|
99 |
+
self.fusion = nn.Sequential(
|
100 |
+
nn.Conv2d(512, 512, kernel_size=1),
|
101 |
+
nn.BatchNorm2d(512),
|
102 |
+
nn.ReLU(),
|
103 |
+
nn.Dropout2d(0.2),
|
104 |
+
nn.Conv2d(512, 512, kernel_size=1),
|
105 |
+
nn.BatchNorm2d(512),
|
106 |
+
nn.ReLU()
|
107 |
+
)
|
108 |
+
|
109 |
+
# Improved Decoders with skip connections
|
110 |
+
self.wind_decoder = nn.Sequential(
|
111 |
+
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
|
112 |
+
nn.BatchNorm2d(256),
|
113 |
+
nn.ReLU(),
|
114 |
+
nn.Dropout2d(0.2),
|
115 |
+
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
|
116 |
+
nn.BatchNorm2d(128),
|
117 |
+
nn.ReLU(),
|
118 |
+
nn.Dropout2d(0.2),
|
119 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
120 |
+
nn.BatchNorm2d(64),
|
121 |
+
nn.ReLU(),
|
122 |
+
nn.Conv2d(64, 1, kernel_size=1),
|
123 |
+
nn.Sigmoid()
|
124 |
+
)
|
125 |
+
|
126 |
+
self.solar_decoder = nn.Sequential(
|
127 |
+
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
|
128 |
+
nn.BatchNorm2d(256),
|
129 |
+
nn.ReLU(),
|
130 |
+
nn.Dropout2d(0.2),
|
131 |
+
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
|
132 |
+
nn.BatchNorm2d(128),
|
133 |
+
nn.ReLU(),
|
134 |
+
nn.Dropout2d(0.2),
|
135 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
136 |
+
nn.BatchNorm2d(64),
|
137 |
+
nn.ReLU(),
|
138 |
+
nn.Conv2d(64, 1, kernel_size=1),
|
139 |
+
nn.Sigmoid()
|
140 |
+
)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
batch_size = x['rgb'].size(0)
|
144 |
+
|
145 |
+
# Resize all inputs to input_size
|
146 |
+
rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
|
147 |
+
ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
|
148 |
+
terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
|
149 |
+
|
150 |
+
# Extract features
|
151 |
+
rgb_features = self.rgb_encoder(rgb_input) # [B, 128, H/4, W/4]
|
152 |
+
ndvi_features = self.ndvi_encoder(ndvi_input) # [B, 128, H/4, W/4]
|
153 |
+
terrain_features = self.terrain_encoder(terrain_input) # [B, 128, H/4, W/4]
|
154 |
+
|
155 |
+
# Process weather features and expand to match feature map size
|
156 |
+
weather_features = self.weather_encoder(x['weather_features']) # [B, 128]
|
157 |
+
weather_features = weather_features.view(batch_size, 128, 1, 1)
|
158 |
+
weather_features = F.interpolate(
|
159 |
+
weather_features,
|
160 |
+
size=self.feature_size,
|
161 |
+
mode='nearest'
|
162 |
+
)
|
163 |
+
|
164 |
+
# Combine features
|
165 |
+
combined_features = torch.cat([
|
166 |
+
rgb_features,
|
167 |
+
ndvi_features,
|
168 |
+
terrain_features,
|
169 |
+
weather_features
|
170 |
+
], dim=1)
|
171 |
+
|
172 |
+
# Apply fusion
|
173 |
+
fused_features = self.fusion(combined_features)
|
174 |
+
|
175 |
+
# Generate predictions and resize to output_size
|
176 |
+
wind_heatmap = self.wind_decoder(fused_features)
|
177 |
+
solar_heatmap = self.solar_decoder(fused_features)
|
178 |
+
|
179 |
+
wind_heatmap = F.interpolate(wind_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
|
180 |
+
solar_heatmap = F.interpolate(solar_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
|
181 |
+
|
182 |
+
return wind_heatmap, solar_heatmap
|
183 |
+
|
184 |
+
class ClimatePredictor:
|
185 |
+
def __init__(self, model_path, device=None):
|
186 |
+
if device is None:
|
187 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
188 |
+
else:
|
189 |
+
self.device = device
|
190 |
+
|
191 |
+
print(f"Using device: {self.device}")
|
192 |
+
|
193 |
+
# Load model
|
194 |
+
self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
|
195 |
+
checkpoint = torch.load(model_path, map_location=self.device)
|
196 |
+
|
197 |
+
if "module" in list(checkpoint['model_state_dict'].keys())[0]:
|
198 |
+
self.model = torch.nn.DataParallel(self.model)
|
199 |
+
|
200 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
201 |
+
self.model.eval()
|
202 |
+
|
203 |
+
self.rgb_transform = transforms.Compose([
|
204 |
+
transforms.Resize((256, 256)),
|
205 |
+
transforms.ToTensor(),
|
206 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
207 |
+
])
|
208 |
+
|
209 |
+
self.single_channel_transform = transforms.Compose([
|
210 |
+
transforms.Resize((256, 256)),
|
211 |
+
transforms.ToTensor(),
|
212 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
213 |
+
])
|
214 |
+
|
215 |
+
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
|
216 |
+
elevation_data, wind_speed, wind_direction,
|
217 |
+
temperature, humidity):
|
218 |
+
"""Gradio ์ธํฐํ์ด์ค์ฉ ์์ธก ํจ์"""
|
219 |
+
# ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ
|
220 |
+
rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
|
221 |
+
ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_image)).unsqueeze(0)
|
222 |
+
terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_image)).unsqueeze(0)
|
223 |
+
|
224 |
+
# ๊ณ ๋ ๋ฐ์ดํฐ ์ฒ๋ฆฌ
|
225 |
+
elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
|
226 |
+
elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
|
227 |
+
|
228 |
+
# ๊ธฐ์ ๋ฐ์ดํฐ ์ฒ๋ฆฌ
|
229 |
+
weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
|
230 |
+
weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
|
231 |
+
weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
|
232 |
+
|
233 |
+
# ๋๋ฐ์ด์ค๋ก ์ด๋
|
234 |
+
sample = {
|
235 |
+
'rgb': rgb_tensor.to(self.device),
|
236 |
+
'ndvi': ndvi_tensor.to(self.device),
|
237 |
+
'terrain': terrain_tensor.to(self.device),
|
238 |
+
'elevation': elevation_tensor.to(self.device),
|
239 |
+
'weather_features': weather_features.to(self.device)
|
240 |
+
}
|
241 |
+
|
242 |
+
# ์์ธก
|
243 |
+
with torch.no_grad():
|
244 |
+
wind_pred, solar_pred = self.model(sample)
|
245 |
+
|
246 |
+
# ๊ฒฐ๊ณผ๋ฅผ numpy ๋ฐฐ์ด๋ก ๋ณํ
|
247 |
+
wind_map = wind_pred.cpu().numpy()[0, 0]
|
248 |
+
solar_map = solar_pred.cpu().numpy()[0, 0]
|
249 |
+
|
250 |
+
# ๊ฒฐ๊ณผ ์๊ฐํ
|
251 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
252 |
+
|
253 |
+
# ํ๋ ฅ ๋ฐ์ ์ ์ฌ๋ ์๊ฐํ
|
254 |
+
sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
|
255 |
+
ax1.set_title('Wind Power Potential Map')
|
256 |
+
|
257 |
+
# ํ์๊ด ๋ฐ์ ์ ์ฌ๋ ์๊ฐํ
|
258 |
+
sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
|
259 |
+
ax2.set_title('Solar Power Potential Map')
|
260 |
+
|
261 |
+
plt.tight_layout()
|
262 |
+
|
263 |
+
return fig
|
264 |
+
|
265 |
+
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
266 |
+
def create_gradio_interface():
|
267 |
+
predictor = ClimatePredictor('best_model.pth')
|
268 |
+
|
269 |
+
def predict_and_visualize(rgb_image, ndvi_image, terrain_image, elevation_file,
|
270 |
+
wind_speed, wind_direction, temperature, humidity):
|
271 |
+
# Load elevation data
|
272 |
+
elevation_data = np.load(elevation_file.name)
|
273 |
+
|
274 |
+
# Generate prediction and visualization
|
275 |
+
result = predictor.predict_from_inputs(
|
276 |
+
rgb_image, ndvi_image, terrain_image, elevation_data,
|
277 |
+
wind_speed, wind_direction, temperature, humidity
|
278 |
+
)
|
279 |
+
return result
|
280 |
+
|
281 |
+
interface = gr.Interface(
|
282 |
+
fn=predict_and_visualize,
|
283 |
+
inputs=[
|
284 |
+
gr.Image(label="RGB Satellite Image", type="numpy"),
|
285 |
+
gr.Image(label="NDVI Image", type="numpy"),
|
286 |
+
gr.Image(label="Terrain Map", type="numpy"),
|
287 |
+
gr.File(label="Elevation Data (NPY file)"),
|
288 |
+
gr.Number(label="Wind Speed (m/s)", value=5.0),
|
289 |
+
gr.Number(label="Wind Direction (degrees)", value=180.0),
|
290 |
+
gr.Number(label="Temperature (ยฐC)", value=25.0),
|
291 |
+
gr.Number(label="Humidity (%)", value=60.0)
|
292 |
+
],
|
293 |
+
outputs=gr.Plot(label="Prediction Results"),
|
294 |
+
title="Renewable Energy Potential Predictor",
|
295 |
+
description="Upload satellite imagery and environmental data to predict wind and solar power potential.",
|
296 |
+
examples=[
|
297 |
+
[
|
298 |
+
"examples/rgb_example.png",
|
299 |
+
"examples/ndvi_example.png",
|
300 |
+
"examples/terrain_example.png",
|
301 |
+
"examples/elevation_example.npy",
|
302 |
+
5.0, 180.0, 25.0, 60.0
|
303 |
+
]
|
304 |
+
]
|
305 |
+
)
|
306 |
+
return interface
|
307 |
+
|
308 |
+
if __name__ == "__main__":
|
309 |
+
interface = create_gradio_interface()
|
310 |
+
interface.launch()s
|