# NewEnergy / app.py
# Author: HoeioUser. Last commit: b9bea0d "Update app.py" (verified), 14.1 kB.
import gradio as gr
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import pandas as pd
from PIL import Image
import os
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import glob
class ClimateNet(nn.Module):
    """Multi-modal network that predicts wind and solar potential heatmaps.

    Inputs (see ``forward``):
    - an RGB image;
    - single-channel NDVI and terrain images;
    - four scalar weather features.

    Each image branch is a convolutional encoder that halves the resolution
    twice (two ``MaxPool2d(2)`` stages) and ends with 128 channels. The weather
    vector is embedded to 128 values by a small MLP and broadcast over the same
    spatial grid.

    The resulting 512 channels are mixed by 1x1 convolutions and decoded by two
    independent heads. Each head has two stride-2 transposed convolutions, so it
    returns to ``input_size``, and ends in a sigmoid. The architecture has no
    residual or skip connections.

    NOTE: attribute names and the layer order inside every ``nn.Sequential``
    define the ``state_dict`` keys. Keep them stable so saved checkpoints still
    load.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size=3, padding=1):
        """Return ``[Conv2d, BatchNorm2d, ReLU]`` as a flat list of layers."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    @staticmethod
    def _downsample():
        """Return a 2x max-pool followed by channel dropout."""
        return [nn.MaxPool2d(2), nn.Dropout2d(0.2)]

    @classmethod
    def _image_encoder(cls, in_ch, deep_second_stage):
        """Build an image encoder: two conv stages, each ending in a 2x downsample.

        The first stage always has two 64-channel convs. The second stage has
        one 128-channel conv, or two when ``deep_second_stage`` is true.
        """
        layers = cls._conv_bn_relu(in_ch, 64) + cls._conv_bn_relu(64, 64) + cls._downsample()
        layers += cls._conv_bn_relu(64, 128)
        if deep_second_stage:
            layers += cls._conv_bn_relu(128, 128)
        layers += cls._downsample()
        return nn.Sequential(*layers)

    @classmethod
    def _heatmap_decoder(cls):
        """Build a decoder head: 512 channels in, 4x upsampled, one sigmoid channel out."""
        layers = []
        for in_ch, out_ch in ((512, 256), (256, 128)):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout2d(0.2),
            ]
        layers += cls._conv_bn_relu(128, 64)
        layers += [nn.Conv2d(64, 1, kernel_size=1), nn.Sigmoid()]
        return nn.Sequential(*layers)

    def __init__(self, input_size=(256, 256), output_size=(64, 64)):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Spatial size after the two 2x max-pooling stages of every image encoder.
        self.feature_size = (input_size[0] // 4, input_size[1] // 4)

        # Layers are created in a fixed order. Reordering would change both the
        # state_dict indices and the parameter initialisation under a fixed seed.
        self.rgb_encoder = self._image_encoder(3, deep_second_stage=True)
        self.ndvi_encoder = self._image_encoder(1, deep_second_stage=False)
        self.terrain_encoder = self._image_encoder(1, deep_second_stage=False)

        weather_layers = []
        for fan_in, fan_out in ((4, 64), (64, 128)):
            weather_layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        weather_layers.append(nn.Linear(128, 128))
        self.weather_encoder = nn.Sequential(*weather_layers)

        # 4 branches x 128 channels = 512 input channels.
        self.fusion = nn.Sequential(
            *self._conv_bn_relu(512, 512, kernel_size=1, padding=0),
            nn.Dropout2d(0.2),
            *self._conv_bn_relu(512, 512, kernel_size=1, padding=0),
        )

        self.wind_decoder = self._heatmap_decoder()
        self.solar_decoder = self._heatmap_decoder()

    def forward(self, x):
        """Predict wind and solar heatmaps.

        Args:
            x: dict with these entries (other keys are ignored):
                - ``'rgb'``: [B, 3, H, W] tensor of any spatial size;
                - ``'ndvi'`` and ``'terrain'``: [B, 1, H, W] tensors;
                - ``'weather_features'``: [B, 4] tensor.

        Returns:
            ``(wind_heatmap, solar_heatmap)``, each [B, 1, *output_size] with
            values in [0, 1].
        """
        def _resize(t):
            return F.interpolate(t, size=self.input_size, mode='bilinear', align_corners=False)

        rgb = self.rgb_encoder(_resize(x['rgb']))
        ndvi = self.ndvi_encoder(_resize(x['ndvi']))
        terrain = self.terrain_encoder(_resize(x['terrain']))

        batch = x['rgb'].size(0)
        weather = self.weather_encoder(x['weather_features']).view(batch, 128, 1, 1)
        # Tile the per-sample weather embedding over the encoder feature grid.
        weather = weather.expand(-1, -1, *self.feature_size)

        fused = self.fusion(torch.cat((rgb, ndvi, terrain, weather), dim=1))
        return tuple(
            F.interpolate(decoder(fused), size=self.output_size, mode='bilinear', align_corners=False)
            for decoder in (self.wind_decoder, self.solar_decoder)
        )
class ClimatePredictor:
    """Inference wrapper around a trained ``ClimateNet`` checkpoint.

    Loads weights from ``model_path`` and preprocesses raw image arrays plus
    scalar weather readings into the tensor dict ``ClimateNet.forward``
    expects. The predicted wind and solar maps are rendered as a matplotlib
    figure.
    """

    def __init__(self, model_path, device=None):
        """Load the checkpoint and build the preprocessing pipelines.

        Args:
            model_path: Path to a checkpoint saved as a dict with a
                ``'model_state_dict'`` entry.
            device: Torch device (or device string) to run on. Defaults to
                CUDA when available, otherwise CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        print(f"Using device: {self.device}")

        self.model = self._load_model(model_path)

        self.rgb_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.single_channel_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def _load_model(self, model_path):
        """Build ClimateNet, load the checkpoint weights and return it in eval mode."""
        model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = checkpoint['model_state_dict']
        # A model saved while wrapped in DataParallel has every key prefixed with
        # "module.". Wrap this model the same way so the keys line up.
        # next(..., "") avoids an IndexError on an empty state dict.
        first_key = next(iter(state_dict), "")
        if first_key.startswith("module."):
            model = torch.nn.DataParallel(model)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def convert_to_single_channel(self, image_array):
        """Reduce an image array to a single 2-D channel.

        Arrays with three or more channels (RGB / RGBA) are combined with the
        luma weights (0.2989, 0.5870, 0.1140); any alpha channel is ignored.
        Arrays with one or two channels (e.g. grayscale + alpha) keep their
        first channel. 2-D arrays are returned unchanged.
        """
        if image_array.ndim != 3:
            return image_array
        if image_array.shape[-1] >= 3:
            return np.dot(image_array[..., :3], [0.2989, 0.5870, 0.1140])
        return image_array[..., 0]

    @staticmethod
    def _min_max_normalize(values):
        """Linearly rescale ``values`` (NumPy array or torch tensor) into [0, 1].

        A constant input (zero range) maps to all zeros instead of the NaNs a
        0/0 division would produce.
        """
        low = values.min()
        span = values.max() - low
        shifted = values - low
        if span == 0:
            return shifted * 0.0
        return shifted / span

    def _single_channel_tensor(self, image_array):
        """Gray-scale an image array and return a normalized [1, 1, 256, 256] tensor."""
        gray = self.convert_to_single_channel(image_array)
        return self.single_channel_transform(Image.fromarray(gray.astype(np.uint8))).unsqueeze(0)

    @staticmethod
    def _plot_heatmaps(wind_map, solar_map):
        """Draw the wind and solar maps side by side and return the figure."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
        ax1.set_title('Wind Power Potential Map')
        sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
        ax2.set_title('Solar Power Potential Map')
        fig.tight_layout()
        # Deregister the figure from pyplot so repeated requests do not pile up
        # open figures. The Figure object itself can still be rendered/saved.
        plt.close(fig)
        return fig

    def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
                            elevation_data, wind_speed, wind_direction,
                            temperature, humidity):
        """Run the model on one sample and return a heatmap figure (Gradio entry point).

        Args:
            rgb_image: Image as a NumPy array (RGB, RGBA or grayscale).
            ndvi_image: Image array; reduced to one channel.
            terrain_image: Image array; reduced to one channel.
            elevation_data: NumPy array of elevations. It is min-max normalized
                and placed in the input dict under ``'elevation'``, but
                ``ClimateNet.forward`` does not read that key.
            wind_speed, wind_direction, temperature, humidity: Scalar readings.

        Returns:
            A matplotlib Figure with the wind and solar potential heatmaps.

        Raises:
            Any exception from preprocessing or inference, re-raised after
            being printed.
        """
        try:
            # convert("RGB") drops an alpha channel and expands grayscale, so
            # the tensor always has the 3 channels Normalize and the encoder expect.
            rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image).convert("RGB")).unsqueeze(0)
            ndvi_tensor = self._single_channel_tensor(ndvi_image)
            terrain_tensor = self._single_channel_tensor(terrain_image)

            elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
            elevation_tensor = self._min_max_normalize(elevation_tensor)

            # NOTE(review): min-max is taken ACROSS the four heterogeneous
            # readings (m/s, degrees, deg C, %), not per feature. This presumably
            # mirrors the training preprocessing; confirm before changing.
            weather = self._min_max_normalize(
                np.array([wind_speed, wind_direction, temperature, humidity])
            )
            weather_features = torch.tensor(weather, dtype=torch.float32).unsqueeze(0)

            sample = {
                'rgb': rgb_tensor.to(self.device),
                'ndvi': ndvi_tensor.to(self.device),
                'terrain': terrain_tensor.to(self.device),
                'elevation': elevation_tensor.to(self.device),
                'weather_features': weather_features.to(self.device)
            }

            with torch.no_grad():
                wind_pred, solar_pred = self.model(sample)

            wind_map = wind_pred.cpu().numpy()[0, 0]
            solar_map = solar_pred.cpu().numpy()[0, 0]
            return self._plot_heatmaps(wind_map, solar_map)
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            raise
def load_examples_from_directory(base_dir):
    """Collect Gradio example rows from the ``base_dir/sample_*`` folders.

    Every sample folder must contain:
      satellite/sentinel2_rgb_2023-07-15_to_2023-09-01.png
      satellite/sentinel2_ndvi_2023-07-15_to_2023-09-01.png
      terrain/terrain_map.png
      terrain/elevation_data.npy
      weather/weather_data.csv
        (columns: wind_speed, wind_direction, temperature, humidity)

    Returns:
        A list of rows of the form
        ``[rgb_path, ndvi_path, terrain_path, elevation_path,
        wind_speed, wind_direction, temperature, humidity]``,
        ordered by folder name. The weather values are float column means
        from the CSV.

        Samples are skipped, with a printed message, when a required file is
        missing, the CSV cannot be read, or a weather column has no valid
        values. Skipping these keeps broken rows out of Gradio's example
        cache. The list is empty when nothing matches.
    """
    weather_columns = ("wind_speed", "wind_direction", "temperature", "humidity")
    examples = []
    for sample_dir in sorted(glob.glob(os.path.join(base_dir, "sample_*"))):
        rgb_path = os.path.join(sample_dir, "satellite", "sentinel2_rgb_2023-07-15_to_2023-09-01.png")
        ndvi_path = os.path.join(sample_dir, "satellite", "sentinel2_ndvi_2023-07-15_to_2023-09-01.png")
        terrain_path = os.path.join(sample_dir, "terrain", "terrain_map.png")
        elevation_path = os.path.join(sample_dir, "terrain", "elevation_data.npy")
        weather_path = os.path.join(sample_dir, "weather", "weather_data.csv")

        required = (rgb_path, ndvi_path, terrain_path, elevation_path, weather_path)
        missing = [path for path in required if not os.path.isfile(path)]
        if missing:
            print(f"Skipping {sample_dir}: missing {', '.join(missing)}")
            continue

        try:
            weather_data = pd.read_csv(weather_path)
            # NOTE(review): wind_direction is averaged arithmetically, which
            # ignores wrap-around at 360 degrees (350 and 10 average to 180).
            # Kept as-is since it presumably matches how the model's inputs
            # were produced; confirm.
            means = [float(weather_data[column].mean()) for column in weather_columns]
        except Exception as e:
            print(f"Error loading example from {sample_dir}: {str(e)}")
            continue

        if any(pd.isna(value) for value in means):
            print(f"Skipping {sample_dir}: weather data has no valid values")
            continue

        examples.append([rgb_path, ndvi_path, terrain_path, elevation_path, *means])
    return examples
def create_gradio_interface():
predictor = ClimatePredictor('best_model.pth')
def predict_and_visualize(rgb_path, ndvi_path, terrain_path, elevation_path,
wind_speed, wind_direction, temperature, humidity):
# ์ด๋ฏธ์ง€ ๋กœ๋“œ
rgb_image = np.array(Image.open(rgb_path))
ndvi_image = np.array(Image.open(ndvi_path))
terrain_image = np.array(Image.open(terrain_path))
elevation_data = np.load(elevation_path)
# ์˜ˆ์ธก ๋ฐ ์‹œ๊ฐํ™”
result = predictor.predict_from_inputs(
rgb_image, ndvi_image, terrain_image, elevation_data,
wind_speed, wind_direction, temperature, humidity
)
return result
# ์˜ˆ์ œ ๋ฐ์ดํ„ฐ ๋กœ๋“œ
examples = load_examples_from_directory("filtered_climate_data")
print(f"Loaded {len(examples)} examples")
interface = gr.Interface(
fn=predict_and_visualize,
inputs=[
gr.Image(label="RGB Satellite Image", type="filepath"),
gr.Image(label="NDVI Image", type="filepath"),
gr.Image(label="Terrain Map", type="filepath"),
gr.File(label="Elevation Data (NPY file)"),
gr.Number(label="Wind Speed (m/s)", value=5.0),
gr.Number(label="Wind Direction (degrees)", value=180.0),
gr.Number(label="Temperature (ยฐC)", value=25.0),
gr.Number(label="Humidity (%)", value=60.0)
],
outputs=gr.Plot(label="Prediction Results"),
title="Renewable Energy Potential Predictor",
description="""Upload satellite imagery and environmental data to predict wind and solar power potential.
You can also try various examples from our dataset using the Examples section below.""",
examples=examples,
cache_examples=True
)
return interface
if __name__ == "__main__":
    # Script entry point: build the app (loads the model and example data) and serve it.
    create_gradio_interface().launch()