|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import pandas as pd |
|
import os |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
import os |
|
import torchvision.transforms as transforms |
|
from sklearn.model_selection import train_test_split |
|
import glob |
|
|
|
class ClimateNet(nn.Module):
    """Multi-branch CNN turning imagery plus weather scalars into two heatmaps.

    Three convolutional encoders (RGB, NDVI, terrain) and an MLP over four
    weather values each emit 128 channels. The four results are concatenated
    to 512 channels, mixed by 1x1 convolutions, and decoded by two parallel
    heads into single-channel wind and solar maps passed through a sigmoid.
    """

    def __init__(self, input_size=(256, 256), output_size=(64, 64)):
        """Create every sub-network.

        Args:
            input_size: (height, width) each image input is resized to
                before it is encoded.
            output_size: (height, width) of the heatmaps returned by
                ``forward``.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Every image encoder contains two 2x2 max-pools, so its feature map
        # is a quarter of the input resolution along each axis.
        self.feature_size = (input_size[0] // 4, input_size[1] // 4)

        # Layers are instantiated in a fixed order and keep fixed positions
        # inside each Sequential so parameter names do not shift.
        self.rgb_encoder = nn.Sequential(
            *self._conv_bn_relu(3, 64),
            *self._conv_bn_relu(64, 64),
            *self._pool_dropout(),
            *self._conv_bn_relu(64, 128),
            *self._conv_bn_relu(128, 128),
            *self._pool_dropout(),
        )
        self.ndvi_encoder = self._make_single_channel_encoder()
        self.terrain_encoder = self._make_single_channel_encoder()

        # MLP lifting the 4 weather scalars to a 128-dimensional vector.
        self.weather_encoder = nn.Sequential(
            nn.Linear(4, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
        )

        # 1x1 convolutions mixing the 4 x 128 concatenated channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        self.wind_decoder = self._make_decoder()
        self.solar_decoder = self._make_decoder()

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return [3x3 same-padding conv, batch norm, ReLU] as a layer list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    @staticmethod
    def _pool_dropout():
        """Return [2x2 max-pool, channel dropout p=0.2] as a layer list."""
        return [nn.MaxPool2d(2), nn.Dropout2d(0.2)]

    @classmethod
    def _make_single_channel_encoder(cls):
        """Build the encoder layout shared by the NDVI and terrain branches."""
        return nn.Sequential(
            *cls._conv_bn_relu(1, 64),
            *cls._conv_bn_relu(64, 64),
            *cls._pool_dropout(),
            *cls._conv_bn_relu(64, 128),
            *cls._pool_dropout(),
        )

    @staticmethod
    def _make_decoder():
        """Build one heatmap head: two 2x transposed-conv upsamplings, a 3x3
        conv, then a 1x1 conv to one channel followed by a sigmoid."""
        return nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the network on one batch.

        Args:
            x: Mapping with tensors under ``'rgb'``, ``'ndvi'``,
                ``'terrain'`` and ``'weather_features'``. The image entries
                are 4-D (they are passed to 2-D bilinear interpolation);
                ``'weather_features'`` feeds ``nn.Linear(4, ...)``.

        Returns:
            ``(wind_heatmap, solar_heatmap)``, each resized to
            ``self.output_size``.
        """
        n_samples = x['rgb'].size(0)

        def to_input_size(image):
            return F.interpolate(image, size=self.input_size, mode='bilinear', align_corners=False)

        image_features = [
            self.rgb_encoder(to_input_size(x['rgb'])),
            self.ndvi_encoder(to_input_size(x['ndvi'])),
            self.terrain_encoder(to_input_size(x['terrain'])),
        ]

        # Tile the per-sample weather vector over the spatial grid so it can
        # be concatenated channel-wise with the image feature maps.
        weather_map = self.weather_encoder(x['weather_features'])
        weather_map = weather_map.view(n_samples, 128, 1, 1)
        weather_map = F.interpolate(weather_map, size=self.feature_size, mode='nearest')

        fused = self.fusion(torch.cat(image_features + [weather_map], dim=1))

        wind_heatmap, solar_heatmap = (
            F.interpolate(head(fused), size=self.output_size, mode='bilinear', align_corners=False)
            for head in (self.wind_decoder, self.solar_decoder)
        )
        return wind_heatmap, solar_heatmap
|
|
|
class ClimatePredictor:
    """Load a trained ClimateNet checkpoint and plot its predictions."""

    def __init__(self, model_path, device=None):
        """Restore the model and set up the image preprocessing pipelines.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` whose
                weights are stored under the ``'model_state_dict'`` key.
            device: Device to run on. When ``None``, CUDA is used if
                available, otherwise the CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        print(f"Using device: {self.device}")

        self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = checkpoint['model_state_dict']

        # A model saved while wrapped in DataParallel has every key prefixed
        # with "module."; wrap ours the same way so the keys line up. Test the
        # prefix, not a substring, so an ordinary layer name that merely
        # contains "module" does not trigger the wrap.
        first_key = next(iter(state_dict), "")
        if first_key.startswith("module."):
            self.model = torch.nn.DataParallel(self.model)

        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.rgb_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.single_channel_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])

    @staticmethod
    def _min_max_scale(values):
        """Rescale *values* linearly to [0, 1]; constant input maps to zeros.

        Guarding the zero range avoids the NaNs a plain
        ``(v - min) / (max - min)`` produces when every entry is equal.
        """
        values = np.asarray(values)
        if not np.issubdtype(values.dtype, np.floating):
            values = values.astype(np.float64)
        low = values.min()
        span = values.max() - low
        if span == 0:
            return np.zeros_like(values)
        return (values - low) / span

    def convert_to_single_channel(self, image_array):
        """Convert an RGB image array to a single channel.

        Args:
            image_array: 2-D array (returned unchanged) or 3-D array with
                channels on the last axis.

        Returns:
            A 2-D array. For 3+ channels the first three are combined with
            luma weights (0.2989, 0.5870, 0.1140); with fewer than three
            channels (e.g. gray + alpha) the first channel is used.
        """
        if image_array.ndim == 3:
            if image_array.shape[-1] < 3:
                # Not enough channels for the RGB weighting; channel 0
                # carries the intensity in gray(+alpha) layouts.
                return image_array[..., 0]
            return np.dot(image_array[..., :3], [0.2989, 0.5870, 0.1140])
        return image_array

    def _to_single_channel_tensor(self, image_array):
        """Gray-convert *image_array* and return it as a (1, 1, H, W) tensor."""
        gray = self.convert_to_single_channel(image_array)
        pil_image = Image.fromarray(gray.astype(np.uint8))
        return self.single_channel_transform(pil_image).unsqueeze(0)

    def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
                            elevation_data, wind_speed, wind_direction,
                            temperature, humidity):
        """Prediction function for the Gradio interface.

        Args:
            rgb_image: Image array accepted by ``PIL.Image.fromarray``; any
                alpha channel is dropped and grayscale is expanded to RGB.
            ndvi_image: NDVI image array (2-D, or 3-D with channels last).
            terrain_image: Terrain image array (2-D, or 3-D channels last).
            elevation_data: Numeric array of elevations.
            wind_speed: Wind speed scalar.
            wind_direction: Wind direction scalar.
            temperature: Temperature scalar.
            humidity: Humidity scalar.

        Returns:
            A matplotlib figure with the wind heatmap on the left and the
            solar heatmap on the right.

        Raises:
            Exception: Any failure is logged with ``print`` and re-raised.
        """
        try:
            # Force exactly three channels so the 3-channel Normalize applies
            # to RGBA and grayscale uploads as well.
            rgb_pil = Image.fromarray(rgb_image).convert('RGB')
            rgb_tensor = self.rgb_transform(rgb_pil).unsqueeze(0)

            ndvi_tensor = self._to_single_channel_tensor(ndvi_image)
            terrain_tensor = self._to_single_channel_tensor(terrain_image)

            elevation = self._min_max_scale(np.asarray(elevation_data, dtype=np.float32))
            elevation_tensor = torch.from_numpy(elevation).unsqueeze(0).unsqueeze(0)

            # The four values are scaled jointly (one shared min/max), as this
            # pipeline has always done. NOTE(review): presumably this mirrors
            # the training-time preprocessing - confirm.
            weather = self._min_max_scale(
                np.array([wind_speed, wind_direction, temperature, humidity])
            )
            weather_features = torch.tensor(weather, dtype=torch.float32).unsqueeze(0)

            sample = {
                'rgb': rgb_tensor.to(self.device),
                'ndvi': ndvi_tensor.to(self.device),
                'terrain': terrain_tensor.to(self.device),
                'elevation': elevation_tensor.to(self.device),
                'weather_features': weather_features.to(self.device),
            }

            with torch.no_grad():
                wind_pred, solar_pred = self.model(sample)

            # First sample, first channel of each prediction.
            wind_map = wind_pred.cpu().numpy()[0, 0]
            solar_map = solar_pred.cpu().numpy()[0, 0]

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

            sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
            ax1.set_title('Wind Power Potential Map')

            sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
            ax2.set_title('Solar Power Potential Map')

            fig.tight_layout()

            # Deregister the figure from pyplot so figures do not pile up in
            # a long-running server; the Figure object stays usable.
            plt.close(fig)

            return fig
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            # Bare raise keeps the original traceback intact.
            raise
|
|
|
def load_examples_from_directory(base_dir):
    """Load example data from the ``sample_*`` folders under *base_dir*.

    Each sample folder is expected to contain::

        satellite/sentinel2_rgb_2023-07-15_to_2023-09-01.png
        satellite/sentinel2_ndvi_2023-07-15_to_2023-09-01.png
        terrain/terrain_map.png
        terrain/elevation_data.npy
        weather/weather_data.csv

    with ``wind_speed``, ``wind_direction``, ``temperature`` and
    ``humidity`` columns in the CSV.

    Args:
        base_dir: Directory that holds the ``sample_*`` folders.

    Returns:
        A list with one row per usable sample, in sorted folder order:
        ``[rgb_path, ndvi_path, terrain_path, elevation_path, wind_speed,
        wind_direction, temperature, humidity]``, where the last four are
        the column means as ``float``. Samples that are incomplete or
        unreadable are reported with ``print`` and skipped.
    """
    weather_columns = ('wind_speed', 'wind_direction', 'temperature', 'humidity')
    examples = []

    for sample_dir in sorted(glob.glob(os.path.join(base_dir, "sample_*"))):
        rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
        ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
        terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
        elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
        weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
        file_paths = [rgb_path, ndvi_path, terrain_path, elevation_path]

        try:
            # Only the CSV is read here, so verify the other files exist now;
            # otherwise a broken sample surfaces later, when the paths are
            # actually opened.
            for required in file_paths:
                if not os.path.isfile(required):
                    raise FileNotFoundError(f"missing file: {required}")

            weather_data = pd.read_csv(weather_path)
            weather_means = [float(weather_data[column].mean()) for column in weather_columns]
        except Exception as e:
            # Best effort: one bad sample must not prevent loading the rest.
            print(f"Error loading example from {sample_dir}: {str(e)}")
            continue

        examples.append(file_paths + weather_means)

    return examples
|
|
|
def create_gradio_interface():
    """Build the Gradio demo around a ``ClimatePredictor``.

    Loads the model from ``best_model.pth`` and the examples from the
    ``filtered_climate_data`` directory, then wires both into a
    ``gr.Interface`` with four file inputs, four numeric inputs and one
    plot output.

    Returns:
        The configured, not yet launched, ``gr.Interface``.
    """
    predictor = ClimatePredictor('best_model.pth')

    def predict_and_visualize(rgb_path, ndvi_path, terrain_path, elevation_path,
                              wind_speed, wind_direction, temperature, humidity):
        """Read the uploaded files and return the prediction figure."""
        supplied = (rgb_path, ndvi_path, terrain_path, elevation_path,
                    wind_speed, wind_direction, temperature, humidity)
        if any(value is None for value in supplied):
            # Shown to the user in the UI instead of an opaque traceback.
            raise gr.Error(
                "Please provide all three images, the elevation file and the four weather values."
            )

        # NOTE(review): depending on the Gradio version/configuration a File
        # input may arrive as a path string or as a temp-file wrapper that
        # exposes the path as `.name`; accept both.
        elevation_path = getattr(elevation_path, "name", elevation_path)

        # Context managers release the file handles as soon as the pixels
        # have been copied into arrays.
        with Image.open(rgb_path) as rgb_file:
            # Normalise to 3 channels so RGBA/palette/grayscale files work.
            rgb_image = np.array(rgb_file.convert('RGB'))
        with Image.open(ndvi_path) as ndvi_file:
            ndvi_image = np.array(ndvi_file)
        with Image.open(terrain_path) as terrain_file:
            terrain_image = np.array(terrain_file)
        elevation_data = np.load(elevation_path)

        return predictor.predict_from_inputs(
            rgb_image, ndvi_image, terrain_image, elevation_data,
            wind_speed, wind_direction, temperature, humidity
        )

    examples = load_examples_from_directory("filtered_climate_data")
    print(f"Loaded {len(examples)} examples")

    interface = gr.Interface(
        fn=predict_and_visualize,
        inputs=[
            gr.Image(label="RGB Satellite Image", type="filepath"),
            gr.Image(label="NDVI Image", type="filepath"),
            gr.Image(label="Terrain Map", type="filepath"),
            gr.File(label="Elevation Data (NPY file)"),
            gr.Number(label="Wind Speed (m/s)", value=5.0),
            gr.Number(label="Wind Direction (degrees)", value=180.0),
            # Escaped degree sign keeps the label independent of the file's
            # encoding (the literal had been corrupted into mojibake).
            gr.Number(label="Temperature (\u00b0C)", value=25.0),
            gr.Number(label="Humidity (%)", value=60.0),
        ],
        outputs=gr.Plot(label="Prediction Results"),
        title="Renewable Energy Potential Predictor",
        description=(
            "Upload satellite imagery and environmental data to predict wind and solar power potential.\n"
            "You can also try various examples from our dataset using the Examples section below."
        ),
        # Do not request an examples section or example caching when no
        # example could be loaded.
        examples=examples or None,
        cache_examples=bool(examples),
    )
    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the demo UI, then start serving it.
    interface = create_gradio_interface()
    interface.launch()
|
|