# NewEnergy / app.py
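"""Gradio app for predicting wind and solar power potential.

Takes an RGB satellite image, an NDVI image, a terrain map, elevation data,
and basic weather readings, and produces wind and solar potential heatmaps
with a multi-branch CNN (ClimateNet).
"""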
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
class ClimateNet(nn.Module):
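    """Multi-branch CNN for renewable-energy potential mapping.

    Separate convolutional encoders process the RGB, NDVI, and terrain inputs,
    an MLP encodes the 4 scalar weather features, and the concatenated feature
    maps are fused with 1x1 convolutions before two decoders produce wind and
    solar potential heatmaps at `output_size`.
    """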
def __init__(self, input_size=(256, 256), output_size=(64, 64)):
super(ClimateNet, self).__init__()
self.input_size = input_size
self.output_size = output_size
# Feature map sizes after two max pooling layers
self.feature_size = (input_size[0] // 4, input_size[1] // 4)
        # RGB encoder: two Conv-BN-ReLU blocks per stage with max pooling and dropout
self.rgb_encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2)
)
        # NDVI encoder (single-channel input)
self.ndvi_encoder = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2)
)
        # Terrain encoder (single-channel input)
self.terrain_encoder = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.2)
)
        # Weather encoder: MLP over the 4 scalar weather features
self.weather_encoder = nn.Sequential(
nn.Linear(4, 64),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(64, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 128)
)
        # Feature fusion: 1x1 convolutions over the concatenated feature maps
self.fusion = nn.Sequential(
nn.Conv2d(512, 512, kernel_size=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.Conv2d(512, 512, kernel_size=1),
nn.BatchNorm2d(512),
nn.ReLU()
)
        # Decoders: transposed convolutions upsample the fused features to heatmaps
self.wind_decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 1, kernel_size=1),
nn.Sigmoid()
)
self.solar_decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
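        """Run a forward pass.

        Args:
            x: dict with keys 'rgb' [B, 3, H, W], 'ndvi' [B, 1, H, W],
               'terrain' [B, 1, H, W], and 'weather_features' [B, 4].
               Spatial inputs are resized to `input_size` internally;
               any 'elevation' entry is currently unused.

        Returns:
            Tuple of (wind_heatmap, solar_heatmap), each [B, 1, *output_size]
            with values in [0, 1] from the final sigmoid.
        """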
batch_size = x['rgb'].size(0)
# Resize all inputs to input_size
rgb_input = F.interpolate(x['rgb'], size=self.input_size, mode='bilinear', align_corners=False)
ndvi_input = F.interpolate(x['ndvi'], size=self.input_size, mode='bilinear', align_corners=False)
terrain_input = F.interpolate(x['terrain'], size=self.input_size, mode='bilinear', align_corners=False)
# Extract features
rgb_features = self.rgb_encoder(rgb_input) # [B, 128, H/4, W/4]
ndvi_features = self.ndvi_encoder(ndvi_input) # [B, 128, H/4, W/4]
terrain_features = self.terrain_encoder(terrain_input) # [B, 128, H/4, W/4]
# Process weather features and expand to match feature map size
weather_features = self.weather_encoder(x['weather_features']) # [B, 128]
weather_features = weather_features.view(batch_size, 128, 1, 1)
weather_features = F.interpolate(
weather_features,
size=self.feature_size,
mode='nearest'
)
# Combine features
combined_features = torch.cat([
rgb_features,
ndvi_features,
terrain_features,
weather_features
], dim=1)
# Apply fusion
fused_features = self.fusion(combined_features)
# Generate predictions and resize to output_size
wind_heatmap = self.wind_decoder(fused_features)
solar_heatmap = self.solar_decoder(fused_features)
wind_heatmap = F.interpolate(wind_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
solar_heatmap = F.interpolate(solar_heatmap, size=self.output_size, mode='bilinear', align_corners=False)
return wind_heatmap, solar_heatmap
class ClimatePredictor:
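    """Loads a trained ClimateNet checkpoint and runs inference for the Gradio UI."""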
def __init__(self, model_path, device=None):
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
print(f"Using device: {self.device}")
# Load model
self.model = ClimateNet(input_size=(256, 256), output_size=(64, 64)).to(self.device)
checkpoint = torch.load(model_path, map_location=self.device)
if "module" in list(checkpoint['model_state_dict'].keys())[0]:
self.model = torch.nn.DataParallel(self.model)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.eval()
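        # RGB inputs use ImageNet normalization statistics; the single-channel
        # NDVI/terrain inputs are scaled to roughly [-1, 1].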
self.rgb_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.single_channel_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
def predict_from_inputs(self, rgb_image, ndvi_image, terrain_image,
elevation_data, wind_speed, wind_direction,
temperature, humidity):
"""Gradio ์ธํ„ฐํŽ˜์ด์Šค์šฉ ์˜ˆ์ธก ํ•จ์ˆ˜"""
# ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
rgb_tensor = self.rgb_transform(Image.fromarray(rgb_image)).unsqueeze(0)
ndvi_tensor = self.single_channel_transform(Image.fromarray(ndvi_image)).unsqueeze(0)
terrain_tensor = self.single_channel_transform(Image.fromarray(terrain_image)).unsqueeze(0)
# ๊ณ ๋„ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ
elevation_tensor = torch.from_numpy(elevation_data).float().unsqueeze(0).unsqueeze(0)
elevation_tensor = (elevation_tensor - elevation_tensor.min()) / (elevation_tensor.max() - elevation_tensor.min())
        # Process weather features (min-max normalized across the four values)
weather_features = np.array([wind_speed, wind_direction, temperature, humidity])
weather_features = (weather_features - weather_features.min()) / (weather_features.max() - weather_features.min())
weather_features = torch.tensor(weather_features, dtype=torch.float32).unsqueeze(0)
# ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™
sample = {
'rgb': rgb_tensor.to(self.device),
'ndvi': ndvi_tensor.to(self.device),
'terrain': terrain_tensor.to(self.device),
'elevation': elevation_tensor.to(self.device),
'weather_features': weather_features.to(self.device)
}
        # Run inference
with torch.no_grad():
wind_pred, solar_pred = self.model(sample)
        # Convert results to numpy arrays
wind_map = wind_pred.cpu().numpy()[0, 0]
solar_map = solar_pred.cpu().numpy()[0, 0]
# ๊ฒฐ๊ณผ ์‹œ๊ฐํ™”
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        # Wind power potential heatmap
sns.heatmap(wind_map, ax=ax1, cmap='YlOrRd', cbar_kws={'label': 'Wind Power Potential'})
ax1.set_title('Wind Power Potential Map')
        # Solar power potential heatmap
sns.heatmap(solar_map, ax=ax2, cmap='YlOrRd', cbar_kws={'label': 'Solar Power Potential'})
ax2.set_title('Solar Power Potential Map')
plt.tight_layout()
return fig
# Create the Gradio interface
def create_gradio_interface():
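    """Builds the Gradio interface around a ClimatePredictor loaded from best_model.pth."""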
predictor = ClimatePredictor('best_model.pth')
def predict_and_visualize(rgb_image, ndvi_image, terrain_image, elevation_file,
wind_speed, wind_direction, temperature, humidity):
# Load elevation data
elevation_data = np.load(elevation_file.name)
# Generate prediction and visualization
result = predictor.predict_from_inputs(
rgb_image, ndvi_image, terrain_image, elevation_data,
wind_speed, wind_direction, temperature, humidity
)
return result
interface = gr.Interface(
fn=predict_and_visualize,
inputs=[
gr.Image(label="RGB Satellite Image", type="numpy"),
gr.Image(label="NDVI Image", type="numpy"),
gr.Image(label="Terrain Map", type="numpy"),
gr.File(label="Elevation Data (NPY file)"),
gr.Number(label="Wind Speed (m/s)", value=5.0),
gr.Number(label="Wind Direction (degrees)", value=180.0),
gr.Number(label="Temperature (ยฐC)", value=25.0),
gr.Number(label="Humidity (%)", value=60.0)
],
outputs=gr.Plot(label="Prediction Results"),
title="Renewable Energy Potential Predictor",
description="Upload satellite imagery and environmental data to predict wind and solar power potential.",
examples=[
[
"examples/rgb_example.png",
"examples/ndvi_example.png",
"examples/terrain_example.png",
"examples/elevation_example.npy",
5.0, 180.0, 25.0, 60.0
]
]
)
return interface
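# A minimal sketch of running the app locally (assumes best_model.pth and the
# examples/ files referenced above sit next to this script):
#   python app.py
# Gradio serves the interface at http://127.0.0.1:7860 by default.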
if __name__ == "__main__":
interface = create_gradio_interface()
    interface.launch()