NewEnergy / app.py
HoeioUser's picture
Update app.py
5dbc592 verified
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import glob
import sys
import csv
# Raise the csv module's per-field size cap as high as the platform allows
# (presumably so the python-engine CSV parsing below can read very large
# fields). csv.field_size_limit() raises OverflowError for values that do not
# fit its C integer type (e.g. sys.maxsize on Windows), so back off by 10x
# until a value is accepted.
maxInt = sys.maxsize
while True:
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        # Integer division keeps the candidate exact; int(maxInt / 10) would
        # round-trip through a float and silently perturb large values.
        maxInt //= 10
def _conv_bn_relu(in_channels, out_channels):
    """Return [3x3 same-padding Conv2d, BatchNorm2d, ReLU] as a list of layers."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


def _pool_and_drop():
    """Return [2x2 MaxPool2d, Dropout2d(0.2)]: halves H and W."""
    return [nn.MaxPool2d(2), nn.Dropout2d(0.2)]


def _single_band_encoder():
    """Encoder for a 1-channel image: 1 -> 64 -> 64 -> pool -> 128 -> pool."""
    return nn.Sequential(
        *_conv_bn_relu(1, 64),
        *_conv_bn_relu(64, 64),
        *_pool_and_drop(),
        *_conv_bn_relu(64, 128),
        *_pool_and_drop(),
    )


def _upsample_stage(in_channels, out_channels):
    """Return [stride-2 ConvTranspose2d (doubles H and W), BatchNorm2d, ReLU, Dropout2d]."""
    return [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Dropout2d(0.2),
    ]


def _heatmap_decoder():
    """Decoder from the 512-channel fused map to a 1-channel sigmoid map at 4x resolution."""
    layers = _upsample_stage(512, 256) + _upsample_stage(256, 128)
    layers += _conv_bn_relu(128, 64)
    layers += [nn.Conv2d(64, 1, kernel_size=1), nn.Sigmoid()]
    return nn.Sequential(*layers)


class ClimateNet(nn.Module):
    """Multi-branch CNN that predicts wind and solar potential heatmaps.

    Separate encoders turn the RGB image, the NDVI map, the terrain map and a
    4-value weather vector into 128-channel features each. These are
    concatenated (512 channels), fused with 1x1 convolutions, and decoded by
    two parallel heads that each emit a single-channel sigmoid map.

    Args:
        input_size: (H, W) the image inputs are bilinearly resized to.
        output_size: (H, W) of the returned heatmaps.

    Note:
        Submodule attribute names and layer order determine the state_dict
        keys (e.g. ``rgb_encoder.0.weight``) and must match saved checkpoints.
    """

    def __init__(self, input_size=(256, 256), output_size=(64, 64)):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Each image encoder pools twice by 2, so features are 1/4 of the input size.
        self.feature_size = (input_size[0] // 4, input_size[1] // 4)

        self.rgb_encoder = nn.Sequential(
            *_conv_bn_relu(3, 64),
            *_conv_bn_relu(64, 64),
            *_pool_and_drop(),
            *_conv_bn_relu(64, 128),
            *_conv_bn_relu(128, 128),
            *_pool_and_drop(),
        )
        self.ndvi_encoder = _single_band_encoder()
        self.terrain_encoder = _single_band_encoder()
        self.weather_encoder = nn.Sequential(
            nn.Linear(4, 64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 128),
        )
        self.fusion = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=1), nn.BatchNorm2d(512), nn.ReLU(), nn.Dropout2d(0.2),
            nn.Conv2d(512, 512, kernel_size=1), nn.BatchNorm2d(512), nn.ReLU(),
        )
        self.wind_decoder = _heatmap_decoder()
        self.solar_decoder = _heatmap_decoder()

    def forward(self, x):
        """Return ``(wind_heatmap, solar_heatmap)``, each (N, 1, *output_size).

        Args:
            x: dict with 4-D image tensors ``'rgb'`` (3 channels), ``'ndvi'``
                and ``'terrain'`` (1 channel each), plus ``'weather_features'``
                of shape (N, 4). Other keys (e.g. ``'elevation'``) are ignored.
        """
        def resize(tensor):
            return F.interpolate(tensor, size=self.input_size, mode='bilinear', align_corners=False)

        n = x['rgb'].size(0)
        branches = [
            self.rgb_encoder(resize(x['rgb'])),
            self.ndvi_encoder(resize(x['ndvi'])),
            self.terrain_encoder(resize(x['terrain'])),
        ]
        # Tile the weather embedding over every spatial cell of the image features.
        weather = self.weather_encoder(x['weather_features']).view(n, 128, 1, 1)
        branches.append(weather.expand(-1, -1, *self.feature_size))

        fused = self.fusion(torch.cat(branches, dim=1))
        maps = []
        for decoder in (self.wind_decoder, self.solar_decoder):
            maps.append(F.interpolate(decoder(fused), size=self.output_size,
                                      mode='bilinear', align_corners=False))
        return maps[0], maps[1]
def get_top_percentile_mask(data, percentile=95):
    """Return a boolean mask of the elements of *data* in its top share.

    An element is selected when it is >= ``np.percentile(data, percentile)``
    (NumPy's default linear interpolation). Ties at the threshold are therefore
    included, and the mask can cover slightly more than ``100 - percentile``
    percent of the elements.
    """
    return data >= np.percentile(data, percentile)
class ClimatePredictor:
    """Run a trained ``ClimateNet`` on one scene and plot the results.

    Args:
        model_path: Checkpoint file containing a ``'model_state_dict'`` entry,
            optionally saved from a ``DataParallel``-wrapped model.
        device: Torch device to run on. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, model_path, device=None):
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        print(f"Using device: {self.device}")
        self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot run arbitrary code on load.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        # Strip DataParallel's 'module.' prefix instead of wrapping the model in
        # DataParallel. Wrapping fails at forward time when CUDA is present but
        # the model lives on another device (e.g. an explicit CPU device).
        state_dict = self._strip_data_parallel_prefix(checkpoint['model_state_dict'])
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # ImageNet channel statistics for the 3-channel RGB input.
        self.rgb_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Single-band inputs (NDVI, terrain): ToTensor gives [0, 1], this maps to [-1, 1].
        self.single_channel_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    @staticmethod
    def _strip_data_parallel_prefix(state_dict):
        """Return *state_dict* with a leading ``'module.'`` removed from its keys.

        The input is returned unchanged (same object) when no key has the prefix.
        """
        from collections import OrderedDict

        prefix = 'module.'

        def strip(key):
            return key[len(prefix):] if key.startswith(prefix) else key

        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        stripped = OrderedDict((strip(key), value) for key, value in state_dict.items())
        # Carry over per-module version metadata that load_state_dict consults.
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            stripped._metadata = OrderedDict((strip(key), value) for key, value in metadata.items())
        return stripped

    @staticmethod
    def _minmax_normalize(values):
        """Linearly scale a numpy array or torch tensor to [0, 1].

        When all elements are equal, returns zeros instead of the NaNs that
        ``(v - min) / (max - min)`` would produce.
        """
        low = values.min()
        span = values.max() - low
        if span == 0:
            return values - low
        return (values - low) / span

    @staticmethod
    def _resize_map(score_map, width, height):
        """Resize a [0, 1] score map to (height, width) through an 8-bit PIL image."""
        image = Image.fromarray((score_map * 255).astype(np.uint8))
        return np.array(image.resize((width, height))) / 255.0

    def highlight_top_potential(self, rgb_image, wind_map, solar_map, percentile=95):
        """Tint the top-potential wind and solar areas on a copy of *rgb_image*.

        Pixels at or above the ``percentile``-th value of the (resized) wind map
        are blended 30% toward green, and those of the solar map toward red.
        Where both apply, solar is blended on top of wind. Grayscale inputs are
        expanded to RGB and any alpha channel is dropped.

        Returns:
            uint8 array of shape (H, W, 3).
        """
        base = np.asarray(rgb_image)
        if base.ndim == 2:
            base = np.stack([base] * 3, axis=-1)
        # Blend in float and round once at the end.
        result = base[..., :3].astype(np.float64)
        h, w = result.shape[:2]

        wind_mask = get_top_percentile_mask(self._resize_map(wind_map, w, h), percentile)
        solar_mask = get_top_percentile_mask(self._resize_map(solar_map, w, h), percentile)

        alpha = 0.3
        for mask, color in ((wind_mask, (0, 255, 0)), (solar_mask, (255, 0, 0))):
            result[mask] = result[mask] * (1 - alpha) + np.array(color) * alpha
        return np.clip(np.rint(result), 0, 255).astype(np.uint8)

    def convert_to_single_channel(self, image_array):
        """Reduce an image array to a single 2-D channel.

        - (H, W, C>=3): luma-weighted sum of the first three channels, using
          weights 0.2989/0.5870/0.1140; any alpha channel is ignored.
        - (H, W, 1) or (H, W, 2), i.e. gray or gray+alpha: first channel.
        - Anything else (e.g. already 2-D): returned unchanged.
        """
        if image_array.ndim == 3:
            if image_array.shape[-1] < 3:
                return image_array[..., 0]
            return np.dot(image_array[..., :3], [0.2989, 0.5870, 0.1140])
        return image_array

    def overlay_emojis(self, rgb_image, wind_map, solar_map, percentile=95):
        """Draw wind/solar emoji over grid cells that are mostly top-potential.

        The image is scanned in cells of ``min(w, h) // 20`` pixels (at least 1)
        with a stride of twice that. A marker is drawn in a cell when more than
        half of the cell lies in the top ``percentile`` region.

        NOTE(review): ImageDraw.text uses PIL's default font here. Depending on
        the Pillow version it may be unable to render emoji; pass a font with
        emoji glyphs if this method is used (it is not called in this file).
        """
        img = Image.fromarray(rgb_image)
        draw = ImageDraw.Draw(img)
        h, w = rgb_image.shape[:2]

        wind_mask = get_top_percentile_mask(self._resize_map(wind_map, w, h), percentile)
        solar_mask = get_top_percentile_mask(self._resize_map(solar_map, w, h), percentile)

        # Guard: images smaller than 20 px would otherwise give range() a step of 0.
        emoji_size = max(1, min(w, h) // 20)
        grid_step = emoji_size * 2
        for y in range(0, h - emoji_size, grid_step):
            for x in range(0, w - emoji_size, grid_step):
                cell = (slice(y, y + emoji_size), slice(x, x + emoji_size))
                text = ""
                if wind_mask[cell].mean() > 0.5:
                    text += "💨"
                if solar_mask[cell].mean() > 0.5:
                    text += "☀️"
                if text:
                    draw.text((x, y), text, fill="white")
        return np.array(img)

    def _single_channel_tensor(self, image):
        """Grayscale *image*, apply the single-band transform and add a batch dim."""
        gray = self.convert_to_single_channel(image)
        return self.single_channel_transform(Image.fromarray(gray.astype(np.uint8))).unsqueeze(0)

    def _build_sample(self, rgb_pil, ndvi_image, terrain_image, elevation_data, weather_values):
        """Turn raw inputs into the batch-of-one tensor dict that ``self.model`` consumes.

        Args:
            rgb_pil: RGB ``PIL.Image``.
            ndvi_image, terrain_image: image arrays (reduced to one channel).
            elevation_data: numpy array, min-max scaled to [0, 1].
            weather_values: (wind_speed, wind_direction, temperature, humidity).
        """
        rgb_tensor = self.rgb_transform(rgb_pil).unsqueeze(0)
        ndvi_tensor = self._single_channel_tensor(ndvi_image)
        terrain_tensor = self._single_channel_tensor(terrain_image)
        # ClimateNet.forward, as written, does not read the 'elevation' entry.
        elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
        elevation_tensor = self._minmax_normalize(elevation_tensor)
        # NOTE(review): this scales the four weather values against each other
        # (mixed units), not per feature. Presumably it mirrors the training
        # preprocessing, so it is kept; confirm before changing.
        weather = self._minmax_normalize(np.array(weather_values, dtype=np.float64))
        weather_tensor = torch.tensor(weather, dtype=torch.float32).unsqueeze(0)
        return {
            'rgb': rgb_tensor.to(self.device),
            'ndvi': ndvi_tensor.to(self.device),
            'terrain': terrain_tensor.to(self.device),
            'elevation': elevation_tensor.to(self.device),
            'weather_features': weather_tensor.to(self.device)
        }

    def _render_figure(self, rgb_array, wind_map, solar_map):
        """Build the 2x2 results figure: overlay, wind heatmap, solar heatmap, top-5% map."""
        fig = plt.figure(figsize=(20, 12))

        # 1. Original image with the top-5% areas highlighted.
        ax1 = fig.add_subplot(2, 2, 1)
        ax1.imshow(self.highlight_top_potential(rgb_array, wind_map, solar_map))
        ax1.set_title('Top 5% Potential Sites\n(Red: Solar, Green: Wind)', pad=20)
        ax1.axis('off')

        # 2. Wind power potential heatmap.
        ax2 = fig.add_subplot(2, 2, 2)
        sns.heatmap(wind_map, ax=ax2, cmap='YlOrRd',
                    cbar_kws={'label': 'Wind Power Potential'})
        ax2.set_title('Wind Power Potential Map')

        # 3. Solar power potential heatmap.
        ax3 = fig.add_subplot(2, 2, 3)
        sns.heatmap(solar_map, ax=ax3, cmap='YlOrRd',
                    cbar_kws={'label': 'Solar Power Potential'})
        ax3.set_title('Solar Power Potential Map')

        # 4. Only the top-5% cells: solar in the red channel, wind in green.
        ax4 = fig.add_subplot(2, 2, 4)
        wind_top = np.where(get_top_percentile_mask(wind_map, 95), wind_map, 0)
        solar_top = np.where(get_top_percentile_mask(solar_map, 95), solar_map, 0)
        ax4.imshow(np.stack([solar_top, wind_top, np.zeros_like(wind_map)], axis=-1))
        ax4.set_title('Top 5% Potential Sites Heatmap\n(Red: Solar, Green: Wind)', pad=20)
        ax4.axis('off')

        fig.tight_layout()
        # Unregister from pyplot so repeated requests do not accumulate open
        # figures; the Figure object itself stays usable for rendering.
        plt.close(fig)
        return fig

    def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
                            elevation_data, wind_speed, wind_direction,
                            temperature, humidity):
        """Predict wind/solar potential for one scene and return a matplotlib Figure.

        Args:
            rgb_image: image array (grayscale, RGB or RGBA; converted to RGB).
            ndvi_image, terrain_image: image arrays, reduced to one channel.
            elevation_data: numpy array of elevations.
            wind_speed, wind_direction, temperature, humidity: scalar weather values.

        Raises:
            Any exception from preprocessing, inference or plotting is printed
            and re-raised.
        """
        try:
            rgb_pil = Image.fromarray(rgb_image).convert('RGB')
            sample = self._build_sample(
                rgb_pil, ndvi_image, terrain_image, elevation_data,
                (wind_speed, wind_direction, temperature, humidity),
            )
            with torch.no_grad():
                wind_pred, solar_pred = self.model(sample)
            wind_map = wind_pred.cpu().numpy()[0, 0]
            solar_map = solar_pred.cpu().numpy()[0, 0]
            return self._render_figure(np.asarray(rgb_pil), wind_map, solar_map)
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            raise
def load_examples_from_directory(base_dir):
    """Collect Gradio example rows from the ``sample_*`` folders under *base_dir*.

    Each sample folder must contain::

        satellite/sentinel2_rgb_2023-07-15_to_2023-09-01.png
        satellite/sentinel2_ndvi_2023-07-15_to_2023-09-01.png
        terrain/terrain_map.png
        terrain/elevation_data.npy
        weather/weather_data.csv  (columns wind_speed, wind_direction,
                                   temperature, humidity)

    Returns:
        A list of rows ``[rgb_path, ndvi_path, terrain_path, elevation_path,
        wind_speed, wind_direction, temperature, humidity]``. The weather values
        are float column means of the CSV. Folders with a missing file, an
        unreadable CSV, or a missing, non-numeric or all-empty weather column
        are skipped with a printed message.
    """
    weather_columns = ['wind_speed', 'wind_direction', 'temperature', 'humidity']
    examples = []
    for sample_dir in sorted(glob.glob(os.path.join(base_dir, "sample_*"))):
        rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
        ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
        terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
        elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
        weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")
        file_paths = [rgb_path, ndvi_path, terrain_path, elevation_path]

        # A row pointing at a missing file would only fail later, when Gradio
        # loads or caches the example, so drop it here.
        missing = [p for p in file_paths + [weather_path] if not os.path.isfile(p)]
        if missing:
            print(f"Skipping {sample_dir}: missing {', '.join(missing)}")
            continue

        try:
            weather_data = pd.read_csv(weather_path, engine='python')
        except Exception as e:
            print(f"Error reading CSV file {weather_path}: {str(e)}")
            continue

        try:
            means = [float(weather_data[col].mean()) for col in weather_columns]
        except Exception as e:
            print(f"Error loading example from {sample_dir}: {str(e)}")
            continue
        if any(np.isnan(m) for m in means):
            print(f"Skipping {sample_dir}: weather column with no numeric values")
            continue

        examples.append(file_paths + means)
        print(f"Successfully loaded example from {sample_dir}")
    print(f"Total examples loaded: {len(examples)}")
    return examples
def create_gradio_interface():
    """Build the Gradio Blocks app for the renewable-energy predictor.

    Loads ``best_model.pth`` into a ``ClimatePredictor`` and lays out the image,
    elevation-file and weather inputs. Both the predict button and the example
    gallery (read from ``filtered_climate_data``) are wired to the predictor.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    predictor = ClimatePredictor('best_model.pth')

    def process_elevation_file(file_obj):
        """Load the uploaded ``.npy`` elevation array.

        Gradio may pass either a filesystem path (str) or a tempfile-like object
        exposing ``.name``, depending on version and configuration.
        np.load's default allow_pickle=False keeps untrusted uploads from
        unpickling objects.
        """
        if file_obj is None:
            raise gr.Error("Please upload an elevation data file (.npy).")
        path = file_obj if isinstance(file_obj, str) else file_obj.name
        return np.load(path)

    def predict_with_processing(*args):
        """Validate UI inputs, load the elevation file and run the predictor."""
        rgb_image, ndvi_image, terrain_image, elevation_file = args[:4]
        weather_params = args[4:]
        missing = [label for label, value in (("RGB satellite image", rgb_image),
                                              ("NDVI image", ndvi_image),
                                              ("terrain map", terrain_image))
                   if value is None]
        if missing:
            raise gr.Error("Please provide: " + ", ".join(missing))
        if any(value is None for value in weather_params):
            raise gr.Error("Please fill in all weather parameters.")
        elevation_data = process_elevation_file(elevation_file)
        return predictor.predict_from_inputs(
            rgb_image, ndvi_image, terrain_image, elevation_data,
            *weather_params
        )

    with gr.Blocks(css="""
    .contain {margin-left: auto; margin-right: auto}
    .output-plot {min-height: 600px !important; width: 100% !important;}
    """) as interface:
        gr.Markdown("# Renewable Energy Potential Predictor")
        # Input section: every input sits in a single row.
        with gr.Row(equal_height=True):
            # Image / file inputs.
            rgb_input = gr.Image(label="RGB Satellite Image", type="numpy", height=150, scale=1)
            ndvi_input = gr.Image(label="NDVI Image", type="numpy", height=150, scale=1)
            terrain_input = gr.Image(label="Terrain Map", type="numpy", height=150, scale=1)
            elevation_input = gr.File(label="Elevation Data (NPY)", scale=1)
            # Weather parameter inputs.
            with gr.Column(scale=1):
                wind_speed = gr.Number(label="Wind Speed (m/s)", value=5.0)
                wind_direction = gr.Number(label="Wind Direction (°)", value=180.0)
                temperature = gr.Number(label="Temperature (°C)", value=25.0)
                humidity = gr.Number(label="Humidity (%)", value=60.0)
        predict_btn = gr.Button("Generate Predictions", variant="primary", size="lg")
        # Output section: uses the full width.
        with gr.Column(elem_classes="output-plot"):
            output_plot = gr.Plot(label="Prediction Results", container=True)

        inputs = [rgb_input, ndvi_input, terrain_input, elevation_input,
                  wind_speed, wind_direction, temperature, humidity]

        # Example section: only shown when at least one sample folder loaded.
        examples = load_examples_from_directory("filtered_climate_data")
        if examples:
            gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=output_plot,
                fn=predict_with_processing,
                cache_examples=True,
                label="Click any example to run",
                examples_per_page=5
            )
        else:
            print("No examples found in 'filtered_climate_data'; skipping the examples gallery.")

        predict_btn.click(
            fn=predict_with_processing,
            inputs=inputs,
            outputs=output_plot
        )
    return interface
if __name__ == "__main__":
    # Build the UI, then serve it on all network interfaces at port 7860
    # with a public share link enabled.
    interface = create_gradio_interface()
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    interface.launch(**launch_options)