HoeioUser commited on
Commit
3f71eec
ยท
verified ยท
1 Parent(s): 3164ac0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +310 -0
app.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ import pandas as pd
7
+ import os
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
14
+ import numpy as np
15
+ import pandas as pd
16
+ from PIL import Image
17
+ import os
18
+ import torchvision.transforms as transforms
19
+ from sklearn.model_selection import train_test_split
20
+
21
+ class ClimateNet(nn.Module):
22
+ def __init__(self, input_size=(256, 256), output_size=(64, 64)):
23
+ super(ClimateNet, self).__init__()
24
+ self.input_size = input_size
25
+ self.output_size = output_size
26
+
27
+ # Feature map sizes after two max pooling layers
28
+ self.feature_size = (input_size[0] // 4, input_size[1] // 4)
29
+
30
+ # Improved RGB Encoder with residual connections
31
+ self.rgb_encoder = nn.Sequential(
32
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
33
+ nn.BatchNorm2d(64),
34
+ nn.ReLU(),
35
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
36
+ nn.BatchNorm2d(64),
37
+ nn.ReLU(),
38
+ nn.MaxPool2d(2),
39
+ nn.Dropout2d(0.2),
40
+
41
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(128),
43
+ nn.ReLU(),
44
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
45
+ nn.BatchNorm2d(128),
46
+ nn.ReLU(),
47
+ nn.MaxPool2d(2),
48
+ nn.Dropout2d(0.2)
49
+ )
50
+
51
+ # Improved NDVI Encoder
52
+ self.ndvi_encoder = nn.Sequential(
53
+ nn.Conv2d(1, 64, kernel_size=3, padding=1),
54
+ nn.BatchNorm2d(64),
55
+ nn.ReLU(),
56
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
57
+ nn.BatchNorm2d(64),
58
+ nn.ReLU(),
59
+ nn.MaxPool2d(2),
60
+ nn.Dropout2d(0.2),
61
+
62
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
63
+ nn.BatchNorm2d(128),
64
+ nn.ReLU(),
65
+ nn.MaxPool2d(2),
66
+ nn.Dropout2d(0.2)
67
+ )
68
+
69
+ # Improved Terrain Encoder
70
+ self.terrain_encoder = nn.Sequential(
71
+ nn.Conv2d(1, 64, kernel_size=3, padding=1),
72
+ nn.BatchNorm2d(64),
73
+ nn.ReLU(),
74
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
75
+ nn.BatchNorm2d(64),
76
+ nn.ReLU(),
77
+ nn.MaxPool2d(2),
78
+ nn.Dropout2d(0.2),
79
+
80
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
81
+ nn.BatchNorm2d(128),
82
+ nn.ReLU(),
83
+ nn.MaxPool2d(2),
84
+ nn.Dropout2d(0.2)
85
+ )
86
+
87
+ # Improved Weather Encoder with deeper architecture
88
+ self.weather_encoder = nn.Sequential(
89
+ nn.Linear(4, 64),
90
+ nn.ReLU(),
91
+ nn.Dropout(0.2),
92
+ nn.Linear(64, 128),
93
+ nn.ReLU(),
94
+ nn.Dropout(0.2),
95
+ nn.Linear(128, 128)
96
+ )
97
+
98
+ # Improved Feature Fusion
99
+ self.fusion = nn.Sequential(
100
+ nn.Conv2d(512, 512, kernel_size=1),
101
+ nn.BatchNorm2d(512),
102
+ nn.ReLU(),
103
+ nn.Dropout2d(0.2),
104
+ nn.Conv2d(512, 512, kernel_size=1),
105
+ nn.BatchNorm2d(512),
106
+ nn.ReLU()
107
+ )
108
+
109
+ # Improved Decoders with skip connections
110
+ self.wind_decoder = nn.Sequential(
111
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
112
+ nn.BatchNorm2d(256),
113
+ nn.ReLU(),
114
+ nn.Dropout2d(0.2),
115
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
116
+ nn.BatchNorm2d(128),
117
+ nn.ReLU(),
118
+ nn.Dropout2d(0.2),
119
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
120
+ nn.BatchNorm2d(64),
121
+ nn.ReLU(),
122
+ nn.Conv2d(64, 1, kernel_size=1),
123
+ nn.Sigmoid()
124
+ )
125
+
126
+ self.solar_decoder = nn.Sequential(
127
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
128
+ nn.BatchNorm2d(256),
129
+ nn.ReLU(),
130
+ nn.Dropout2d(0.2),
131
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
132
+ nn.BatchNorm2d(128),
133
+ nn.ReLU(),
134
+ nn.Dropout2d(0.2),
135
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
136
+ nn.BatchNorm2d(64),
137
+ nn.ReLU(),
138
+ nn.Conv2d(64, 1, kernel_size=1),
139
+ nn.Sigmoid()
140
+ )
141
+
142
+ def forward(self, x):
143
+ batch_size = x['rgb'].size(0)
144
+
145
+ # Resize all inputs to input_size
146
+ rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
147
+ ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
148
+ terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
149
+
150
+ # Extract features
151
+ rgb_features = self.rgb_encoder(rgb_input) # [B, 128, H/4, W/4]
152
+ ndvi_features = self.ndvi_encoder(ndvi_input) # [B, 128, H/4, W/4]
153
+ terrain_features = self.terrain_encoder(terrain_input) # [B, 128, H/4, W/4]
154
+
155
+ # Process weather features and expand to match feature map size
156
+ weather_features = self.weather_encoder(x['weather_features']) # [B, 128]
157
+ weather_features = weather_features.view(batch_size, 128, 1, 1)
158
+ weather_features = F.interpolate(
159
+ weather_features,
160
+ size=self.feature_size,
161
+ mode='nearest'
162
+ )
163
+
164
+ # Combine features
165
+ combined_features = torch.cat([
166
+ rgb_features,
167
+ ndvi_features,
168
+ terrain_features,
169
+ weather_features
170
+ ], dim=1)
171
+
172
+ # Apply fusion
173
+ fused_features = self.fusion(combined_features)
174
+
175
+ # Generate predictions and resize to output_size
176
+ wind_heatmap = self.wind_decoder(fused_features)
177
+ solar_heatmap = self.solar_decoder(fused_features)
178
+
179
+ wind_heatmap = F.interpolate(wind_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
180
+ solar_heatmap = F.interpolate(solar_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
181
+
182
+ return wind_heatmap, solar_heatmap
183
+
184
+ class ClimatePredictor:
185
+ def __init__(self, model_path, device=None):
186
+ if device is None:
187
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
188
+ else:
189
+ self.device = device
190
+
191
+ print(f"Using device: {self.device}")
192
+
193
+ # Load model
194
+ self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
195
+ checkpoint = torch.load(model_path, map_location=self.device)
196
+
197
+ if "module" in list(checkpoint['model_state_dict'].keys())[0]:
198
+ self.model = torch.nn.DataParallel(self.model)
199
+
200
+ self.model.load_state_dict(checkpoint['model_state_dict'])
201
+ self.model.eval()
202
+
203
+ self.rgb_transform = transforms.Compose([
204
+ transforms.Resize((256, 256)),
205
+ transforms.ToTensor(),
206
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
207
+ ])
208
+
209
+ self.single_channel_transform = transforms.Compose([
210
+ transforms.Resize((256, 256)),
211
+ transforms.ToTensor(),
212
+ transforms.Normalize(mean=[0.5], std=[0.5])
213
+ ])
214
+
215
+ def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
216
+ elevation_data, wind_speed, wind_direction,
217
+ temperature, humidity):
218
+ """Gradio ์ธํ„ฐํŽ˜์ด์Šค์šฉ ์˜ˆ์ธก ํ•จ์ˆ˜"""
219
+ # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
220
+ rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
221
+ ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_image)).unsqueeze(0)
222
+ terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_image)).unsqueeze(0)
223
+
224
+ # ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
225
+ elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
226
+ elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
227
+
228
+ # ๊ธฐ์ƒ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
229
+ weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
230
+ weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
231
+ weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
232
+
233
+ # ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
234
+ sample = {
235
+ 'rgb': rgb_tensor.to(self.device),
236
+ 'ndvi': ndvi_tensor.to(self.device),
237
+ 'terrain': terrain_tensor.to(self.device),
238
+ 'elevation': elevation_tensor.to(self.device),
239
+ 'weather_features': weather_features.to(self.device)
240
+ }
241
+
242
+ # ์˜ˆ์ธก
243
+ with torch.no_grad():
244
+ wind_pred, solar_pred = self.model(sample)
245
+
246
+ # ๊ฒฐ๊ณผ๋ฅผ numpy ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜
247
+ wind_map = wind_pred.cpu().numpy()[0, 0]
248
+ solar_map = solar_pred.cpu().numpy()[0, 0]
249
+
250
+ # ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
251
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
252
+
253
+ # ํ’๋ ฅ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
254
+ sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
255
+ ax1.set_title('Wind Power Potential Map')
256
+
257
+ # ํƒœ์–‘๊ด‘ ๋ฐœ์ „ ์ž ์žฌ๋Ÿ‰ ์‹œ๊ฐํ™”
258
+ sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
259
+ ax2.set_title('Solar Power Potential Map')
260
+
261
+ plt.tight_layout()
262
+
263
+ return fig
264
+
265
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ƒ์„ฑ
266
+ def create_gradio_interface():
267
+ predictor = ClimatePredictor('best_model.pth')
268
+
269
+ def predict_and_visualize(rgb_image, ndvi_image, terrain_image, elevation_file,
270
+ wind_speed, wind_direction, temperature, humidity):
271
+ # Load elevation data
272
+ elevation_data = np.load(elevation_file.name)
273
+
274
+ # Generate prediction and visualization
275
+ result = predictor.predict_from_inputs(
276
+ rgb_image, ndvi_image, terrain_image, elevation_data,
277
+ wind_speed, wind_direction, temperature, humidity
278
+ )
279
+ return result
280
+
281
+ interface = gr.Interface(
282
+ fn=predict_and_visualize,
283
+ inputs=[
284
+ gr.Image(label="RGB Satellite Image", type="numpy"),
285
+ gr.Image(label="NDVI Image", type="numpy"),
286
+ gr.Image(label="Terrain Map", type="numpy"),
287
+ gr.File(label="Elevation Data (NPY file)"),
288
+ gr.Number(label="Wind Speed (m/s)", value=5.0),
289
+ gr.Number(label="Wind Direction (degrees)", value=180.0),
290
+ gr.Number(label="Temperature (ยฐC)", value=25.0),
291
+ gr.Number(label="Humidity (%)", value=60.0)
292
+ ],
293
+ outputs=gr.Plot(label="Prediction Results"),
294
+ title="Renewable Energy Potential Predictor",
295
+ description="Upload satellite imagery and environmental data to predict wind and solar power potential.",
296
+ examples=[
297
+ [
298
+ "examples/rgb_example.png",
299
+ "examples/ndvi_example.png",
300
+ "examples/terrain_example.png",
301
+ "examples/elevation_example.npy",
302
+ 5.0, 180.0, 25.0, 60.0
303
+ ]
304
+ ]
305
+ )
306
+ return interface
307
+
308
+ if __name__ == "__main__":
309
+ interface = create_gradio_interface()
310
+ interface.launch()s