|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from PIL import Image, ImageDraw |
|
import torchvision.transforms as transforms |
|
import pandas as pd |
|
import os |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
from sklearn.model_selection import train_test_split |
|
import glob |
|
import sys |
|
import csv |
|
|
|
# Raise the csv module's per-field size limit as far as this platform allows.
# csv.field_size_limit() raises OverflowError when the requested value does
# not fit the underlying C integer type, so start at sys.maxsize and shrink
# by a factor of ten until a value is accepted.
maxInt = sys.maxsize

while True:
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        # Floor division keeps the arithmetic in exact integers; the previous
        # int(maxInt / 10) round-tripped through a float and lost precision.
        maxInt //= 10
|
|
|
class ClimateNet(nn.Module):
    """Multi-branch CNN producing a wind heatmap and a solar heatmap.

    Three convolutional encoders (RGB, NDVI, terrain) each reduce their
    input to 128 channels at one quarter of ``input_size``. A small MLP turns
    the four weather values into a 128-vector that is tiled over the same
    grid. The 512-channel concatenation passes through 1x1 fusion
    convolutions and then through two structurally identical decoders, each
    ending in a sigmoid.

    The order and position of every layer inside each ``nn.Sequential`` is
    part of the checkpoint format (state-dict keys are index based), so the
    helpers below must keep emitting layers in exactly this order.
    """

    def __init__(self, input_size=(256, 256), output_size=(64, 64)):
        """Build all sub-networks.

        Args:
            input_size: (height, width) every image input is resampled to.
            output_size: (height, width) of the returned heatmaps.
        """
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        # Each encoder applies two 2x2 max-pools, hence the division by 4.
        self.feature_size = (input_size[0] // 4, input_size[1] // 4)

        self.rgb_encoder = nn.Sequential(
            *self._conv_unit(3, 64),
            *self._conv_unit(64, 64),
            *self._pool_unit(),
            *self._conv_unit(64, 128),
            *self._conv_unit(128, 128),
            *self._pool_unit(),
        )
        self.ndvi_encoder = self._single_band_encoder()
        self.terrain_encoder = self._single_band_encoder()

        self.weather_encoder = nn.Sequential(
            nn.Linear(4, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
        )

        # 512 = 4 branches x 128 channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        self.wind_decoder = self._heatmap_decoder()
        self.solar_decoder = self._heatmap_decoder()

    @staticmethod
    def _conv_unit(in_channels, out_channels):
        """Return [3x3 conv with padding 1, batch norm, ReLU] as a layer list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    @staticmethod
    def _pool_unit():
        """Return [2x2 max-pool, channel dropout (p=0.2)] as a layer list."""
        return [nn.MaxPool2d(2), nn.Dropout2d(0.2)]

    @classmethod
    def _single_band_encoder(cls):
        """Build the encoder shared in shape by the NDVI and terrain branches."""
        return nn.Sequential(
            *cls._conv_unit(1, 64),
            *cls._conv_unit(64, 64),
            *cls._pool_unit(),
            *cls._conv_unit(64, 128),
            *cls._pool_unit(),
        )

    @classmethod
    def _heatmap_decoder(cls):
        """Build one decoder: two stride-2 transposed convs, then a 1-channel sigmoid head."""
        return nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.2),
            *cls._conv_unit(128, 64),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the network on a dict of inputs.

        Args:
            x: mapping with the keys ``'rgb'``, ``'ndvi'``, ``'terrain'`` and
                ``'weather_features'``. Other keys are ignored.

        Returns:
            ``(wind_heatmap, solar_heatmap)``, both resampled to
            ``output_size``.
        """
        batch_size = x['rgb'].size(0)

        # Bring every image input to the common working resolution first.
        resized = {
            key: F.interpolate(x[key], size=self.input_size, mode='bilinear', align_corners=False)
            for key in ('rgb', 'ndvi', 'terrain')
        }

        branches = [
            self.rgb_encoder(resized['rgb']),
            self.ndvi_encoder(resized['ndvi']),
            self.terrain_encoder(resized['terrain']),
        ]

        # Tile the weather embedding over the spatial grid of the image features.
        weather_vector = self.weather_encoder(x['weather_features'])
        weather_grid = F.interpolate(
            weather_vector.view(batch_size, 128, 1, 1),
            size=self.feature_size,
            mode='nearest',
        )
        branches.append(weather_grid)

        fused = self.fusion(torch.cat(branches, dim=1))

        wind_raw = self.wind_decoder(fused)
        solar_raw = self.solar_decoder(fused)

        wind_heatmap, solar_heatmap = (
            F.interpolate(raw, size=self.output_size, mode='bilinear', align_corners=False)
            for raw in (wind_raw, solar_raw)
        )
        return wind_heatmap, solar_heatmap
|
|
|
def get_top_percentile_mask(data, percentile=95):
    """Return a mask that is True where *data* is at or above its *percentile*-th percentile.

    With the default of 95 this selects roughly the top 5% of values.
    """
    return data >= np.percentile(data, percentile)
|
|
|
class ClimatePredictor:
    """Load a trained ``ClimateNet`` checkpoint and turn raw inputs into a result figure.

    The constructor builds the network, loads the weights stored under the
    checkpoint's ``'model_state_dict'`` key and prepares the image
    preprocessing pipelines. ``predict_from_inputs`` is the main entry point.
    """

    def __init__(self, model_path, device=None):
        """Create the model and load its weights.

        Args:
            model_path: path of a checkpoint readable by ``torch.load`` that
                contains a ``'model_state_dict'`` entry.
            device: torch device to run on; when ``None``, CUDA is used if
                available, otherwise the CPU.
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        print(f"Using device: {self.device}")

        self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)

        # Checkpoints saved from a DataParallel-wrapped model carry a
        # "module." prefix on every key. Strip it instead of wrapping the
        # inference model in DataParallel just to make the names match.
        state_dict = self._strip_data_parallel_prefix(checkpoint['model_state_dict'])
        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.rgb_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.single_channel_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    @staticmethod
    def _strip_data_parallel_prefix(state_dict):
        """Return *state_dict* with a leading ``"module."`` removed from its keys.

        The mapping is returned unchanged when no key has that prefix.
        """
        prefix = "module."
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @staticmethod
    def _min_max_scale(values):
        """Scale *values* (NumPy array or torch tensor) linearly to [0, 1].

        A constant input would make the usual formula compute 0/0; in that
        case the input shifted to zero is returned instead of NaNs.
        """
        lowest = values.min()
        spread = values.max() - lowest
        if spread == 0:
            return values - lowest
        return (values - lowest) / spread

    @staticmethod
    def _resize_score_map(score_map, width, height):
        """Resample a score map to ``(height, width)`` through an 8-bit PIL image.

        The map is multiplied by 255 and cast to uint8 before resizing, and
        divided by 255.0 afterwards.
        """
        as_image = Image.fromarray((score_map * 255).astype(np.uint8))
        return np.array(as_image.resize((width, height))) / 255.0

    def highlight_top_potential(self, rgb_image, wind_map, solar_map, percentile=95):
        """Highlight the top-potential regions on a copy of the RGB image.

        Pixels at or above *percentile* of the resized wind map are blended
        with green, those of the solar map with red (30% tint each; the wind
        tint is applied first).

        Returns:
            A uint8 array with the same shape as *rgb_image*.
        """
        result = np.copy(rgb_image)

        h, w = rgb_image.shape[:2]
        wind_map_resized = self._resize_score_map(wind_map, w, h)
        solar_map_resized = self._resize_score_map(solar_map, w, h)

        wind_mask = get_top_percentile_mask(wind_map_resized, percentile)
        solar_mask = get_top_percentile_mask(solar_map_resized, percentile)

        wind_color = np.array([0, 255, 0])
        solar_color = np.array([255, 0, 0])

        alpha = 0.3
        for mask, color in ((wind_mask, wind_color), (solar_mask, solar_color)):
            result[mask] = result[mask] * (1 - alpha) + color * alpha

        return result.astype(np.uint8)

    def convert_to_single_channel(self, image_array):
        """Collapse a 3-D image array to one channel; return 2-D input unchanged.

        The first three channels are combined with the weights
        0.2989 / 0.5870 / 0.1140.
        """
        if len(image_array.shape) == 3:
            return np.dot(image_array[..., :3], [0.2989, 0.5870, 0.1140])
        return image_array

    def overlay_emojis(self, rgb_image, wind_map, solar_map):
        """Overlay emoji markers only where the top 5% of potential is concentrated.

        The image is scanned on a coarse grid; a cell gets a wind and/or sun
        marker when more than half of it lies inside the corresponding
        top-percentile mask.

        Returns:
            The annotated image as a NumPy array.
        """
        img = Image.fromarray(rgb_image)
        draw = ImageDraw.Draw(img)

        h, w = rgb_image.shape[:2]
        wind_map_np = self._resize_score_map(wind_map, w, h)
        solar_map_np = self._resize_score_map(solar_map, w, h)

        wind_mask = get_top_percentile_mask(wind_map_np)
        solar_mask = get_top_percentile_mask(solar_map_np)

        # Keep the cell size at least 1 so that range() never gets a zero step
        # for images smaller than 20 pixels on a side.
        emoji_size = max(1, min(w, h) // 20)
        grid_step = emoji_size * 2

        for y in range(0, h - emoji_size, grid_step):
            for x in range(0, w - emoji_size, grid_step):
                region_wind = wind_mask[y:y + emoji_size, x:x + emoji_size].mean()
                region_solar = solar_mask[y:y + emoji_size, x:x + emoji_size].mean()

                # The literals in the source were mis-decoded UTF-8; they are
                # restored here as the dash/wind and sun emoji.
                # NOTE(review): PIL's default font may not contain emoji
                # glyphs - confirm a suitable font if this method is used.
                text = ""
                if region_wind > 0.5:
                    text += "\U0001F4A8"
                if region_solar > 0.5:
                    text += "\u2600\uFE0F"

                if text:
                    draw.text((x, y), text, fill="white")

        return np.array(img)

    def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
                            elevation_data, wind_speed, wind_direction,
                            temperature, humidity):
        """Run the model on one set of inputs and return a 2x2 matplotlib figure.

        Args:
            rgb_image: RGB image as a NumPy array.
            ndvi_image: NDVI image as a NumPy array (2-D or with channels).
            terrain_image: terrain image as a NumPy array (2-D or with channels).
            elevation_data: NumPy array of elevations.
            wind_speed, wind_direction, temperature, humidity: scalar weather
                values; they are min-max scaled together as one vector.

        Returns:
            A matplotlib figure with the highlighted image, the wind and solar
            heatmaps, and a combined top-5% map.

        Raises:
            ValueError: if any input is ``None``.
            Exception: any other failure is logged and re-raised unchanged.
        """
        try:
            provided = {
                'rgb_image': rgb_image,
                'ndvi_image': ndvi_image,
                'terrain_image': terrain_image,
                'elevation_data': elevation_data,
                'wind_speed': wind_speed,
                'wind_direction': wind_direction,
                'temperature': temperature,
                'humidity': humidity,
            }
            missing = [name for name, value in provided.items() if value is None]
            if missing:
                raise ValueError("Missing required input(s): " + ", ".join(missing))

            rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)

            ndvi_gray = self.convert_to_single_channel(ndvi_image)
            ndvi_tensor = self.single_channel_transform(
                Image.fromarray(ndvi_gray.astype(np.uint8))).unsqueeze(0)

            terrain_gray = self.convert_to_single_channel(terrain_image)
            terrain_tensor = self.single_channel_transform(
                Image.fromarray(terrain_gray.astype(np.uint8))).unsqueeze(0)

            elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
            elevation_tensor = self._min_max_scale(elevation_tensor)

            weather_features = np.array(
                [wind_speed, wind_direction, temperature, humidity], dtype=np.float64)
            weather_features = self._min_max_scale(weather_features)
            weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)

            sample = {
                'rgb': rgb_tensor.to(self.device),
                'ndvi': ndvi_tensor.to(self.device),
                'terrain': terrain_tensor.to(self.device),
                'elevation': elevation_tensor.to(self.device),
                'weather_features': weather_features.to(self.device)
            }

            with torch.no_grad():
                wind_pred, solar_pred = self.model(sample)

            # First sample, first channel of each prediction.
            wind_map = wind_pred.cpu().numpy()[0, 0]
            solar_map = solar_pred.cpu().numpy()[0, 0]

            fig = plt.figure(figsize=(20, 12))
            try:
                ax1 = fig.add_subplot(2, 2, 1)
                highlighted_img = self.highlight_top_potential(rgb_image, wind_map, solar_map)
                ax1.imshow(highlighted_img)
                ax1.set_title('Top 5% Potential Sites\n(Red: Solar, Green: Wind)', pad=20)
                ax1.axis('off')

                ax2 = fig.add_subplot(2, 2, 2)
                sns.heatmap(wind_map, ax=ax2, cmap='YlOrRd',
                            cbar_kws={'label': 'Wind Power Potential'})
                ax2.set_title('Wind Power Potential Map')

                ax3 = fig.add_subplot(2, 2, 3)
                sns.heatmap(solar_map, ax=ax3, cmap='YlOrRd',
                            cbar_kws={'label': 'Solar Power Potential'})
                ax3.set_title('Solar Power Potential Map')

                ax4 = fig.add_subplot(2, 2, 4)
                wind_top = np.where(wind_map >= np.percentile(wind_map, 95), wind_map, 0)
                solar_top = np.where(solar_map >= np.percentile(solar_map, 95), solar_map, 0)
                # Channel order R, G, B: solar in red, wind in green, blue empty.
                combined_map = np.stack([solar_top, wind_top, np.zeros_like(wind_map)], axis=-1)
                ax4.imshow(combined_map)
                ax4.set_title('Top 5% Potential Sites Heatmap\n(Red: Solar, Green: Wind)', pad=20)
                ax4.axis('off')

                fig.tight_layout()
            finally:
                # Unregister the figure from pyplot so repeated predictions do
                # not accumulate open figures; the Figure object itself stays
                # usable by the caller.
                plt.close(fig)

            return fig

        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            raise
|
|
|
def load_examples_from_directory(base_dir):
    """Load example rows from the ``sample_*`` folders under *base_dir*.

    Each usable sample yields one row:
    ``[rgb_path, ndvi_path, terrain_path, elevation_path, wind_speed,
    wind_direction, temperature, humidity]`` where the four weather values
    are the column means of the sample's ``weather/weather_data.csv``.

    A sample is skipped (with a printed message) when one of its image or
    elevation files is missing, when the CSV cannot be read or lacks a
    column, or when a weather mean is not a finite number. Returning such
    rows would only defer the failure to whoever opens the example.

    Returns:
        A list of rows; empty when nothing usable is found.
    """
    examples = []
    sample_dirs = sorted(glob.glob(os.path.join(base_dir, "sample_*")))

    for sample_dir in sample_dirs:
        try:
            asset_paths = [
                os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png"),
                os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png"),
                os.path.join(sample_dir, "terrain", "terrain_map.png"),
                os.path.join(sample_dir, "terrain", "elevation_data.npy"),
            ]
            weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")

            missing = [path for path in asset_paths if not os.path.isfile(path)]
            if missing:
                print(f"Skipping {sample_dir}: missing file(s) {missing}")
                continue

            try:
                weather_data = pd.read_csv(weather_path, engine='python')
            except Exception as e:
                print(f"Error reading CSV file {weather_path}: {str(e)}")
                continue

            weather_means = [
                float(weather_data[column].mean())
                for column in ('wind_speed', 'wind_direction', 'temperature', 'humidity')
            ]
            # An empty or all-NaN column averages to NaN, which is useless as
            # a default value for the numeric inputs.
            if not np.all(np.isfinite(weather_means)):
                print(f"Skipping {sample_dir}: weather averages are not finite")
                continue

            examples.append([*asset_paths, *weather_means])
            print(f"Successfully loaded example from {sample_dir}")
        except Exception as e:
            print(f"Error loading example from {sample_dir}: {str(e)}")
            continue

    print(f"Total examples loaded: {len(examples)}")
    return examples
|
|
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI around a ``ClimatePredictor``.

    The predictor is loaded from ``'best_model.pth'``. The page offers three
    image inputs, an elevation file upload, four numeric weather inputs, a
    predict button, a plot output, and clickable examples read from the
    ``"filtered_climate_data"`` directory.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """
    predictor = ClimatePredictor('best_model.pth')

    def process_elevation_file(file_obj):
        """Load the elevation array from a path string or an uploaded-file object."""
        if file_obj is None:
            # Without this guard the failure is an opaque AttributeError on None.
            raise gr.Error("Please provide an elevation data (.npy) file.")
        if isinstance(file_obj, str):
            return np.load(file_obj)
        return np.load(file_obj.name)

    def predict_with_processing(*args):
        """Adapt the flat list of component values to ``predict_from_inputs``."""
        rgb_image, ndvi_image, terrain_image, elevation_file = args[:4]
        weather_params = args[4:]

        elevation_data = process_elevation_file(elevation_file)

        return predictor.predict_from_inputs(
            rgb_image, ndvi_image, terrain_image, elevation_data,
            *weather_params
        )

    # NOTE(review): the source had lost its indentation; the component nesting
    # below (weather column and button inside the input row, plot in its own
    # column) is reconstructed from the blank-line pattern - confirm.
    with gr.Blocks(css="""
        .contain {margin-left: auto; margin-right: auto}
        .output-plot {min-height: 600px !important; width: 100% !important;}
    """) as interface:
        gr.Markdown("# Renewable Energy Potential Predictor")

        with gr.Row(equal_height=True):
            rgb_input = gr.Image(label="RGB Satellite Image", type="numpy", height=150, scale=1)
            ndvi_input = gr.Image(label="NDVI Image", type="numpy", height=150, scale=1)
            terrain_input = gr.Image(label="Terrain Map", type="numpy", height=150, scale=1)
            elevation_input = gr.File(label="Elevation Data (NPY)", scale=1)

            with gr.Column(scale=1):
                wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
                # The degree signs were mis-decoded in the source; written as
                # escapes so the labels survive any further re-encoding.
                wind_direction = gr.Number(label="Wind Direction (\u00b0)", value=180.0)
                temperature = gr.Number(label="Temperature (\u00b0C)", value=25.0)
                humidity = gr.Number(label="Humidity (%)", value=60.0)
                predict_btn = gr.Button("Generate Predictions", variant="primary", size="lg")

        with gr.Column(elem_classes="output-plot"):
            output_plot = gr.Plot(label="Prediction Results", container=True)

        # Same component order as the positional arguments unpacked in
        # predict_with_processing; shared by the examples and the button.
        input_components = [
            rgb_input, ndvi_input, terrain_input, elevation_input,
            wind_speed, wind_direction, temperature, humidity,
        ]

        examples = load_examples_from_directory("filtered_climate_data")
        gr.Examples(
            examples=examples,
            inputs=input_components,
            outputs=output_plot,
            fn=predict_with_processing,
            cache_examples=True,
            label="Click any example to run",
            examples_per_page=5
        )

        predict_btn.click(
            fn=predict_with_processing,
            inputs=input_components,
            outputs=output_plot
        )

    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, then start the server.
    # server_name "0.0.0.0" with share=True makes the app reachable from
    # outside this machine - keep that in mind before running it anywhere
    # sensitive.
    launch_options = {
        "share": True,
        "server_port": 7860,
        "server_name": "0.0.0.0",
    }
    interface = create_gradio_interface()
    interface.launch(**launch_options)