# Page header residue from the hosting site (not code): qq1990 · init · 100edb4 · 13.3 kB
import streamlit as st
import torch
# from Pangu-Weather import *
import numpy as np
from datetime import datetime
import numpy as np
import onnx
import onnxruntime as ort
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import io
def pangu_config_data():
    """Render the Pangu-Weather input description and the two ``.npy`` uploaders.

    Returns:
        tuple: ``(input_surface_file, input_upper_file)``, the values returned by
        the two ``st.file_uploader`` widgets (Streamlit returns ``None`` for an
        uploader that has no file yet).
    """
    st.subheader("Pangu-Weather Model Data Input")

    # Detailed data description. The markdown body is kept flush-left on purpose:
    # nested list items rely on their relative indentation to render as sub-lists.
    st.markdown("""
**Input Data Requirements:**

Pangu-Weather uses two NumPy arrays to represent initial atmospheric conditions:

1. **Surface Data (input_surface.npy)**
   - Shape: `(4, 721, 1440)`
   - Variables: MSLP, U10, V10, T2M in this exact order.
     - **MSLP:** Mean Sea Level Pressure
     - **U10:** 10-meter Eastward Wind
     - **V10:** 10-meter Northward Wind
     - **T2M:** 2-meter Temperature
2. **Upper-Air Data (input_upper.npy)**
   - Shape: `(5, 13, 721, 1440)`
   - Variables (first dim): Z, Q, T, U, V in this exact order
     - **Z:** Geopotential (Note: if your source provides geopotential height, multiply by 9.80665 to get geopotential)
     - **Q:** Specific Humidity
     - **T:** Temperature
     - **U:** Eastward Wind
     - **V:** Northward Wind
   - Pressure Levels (second dim): 1000hPa, 925hPa, 850hPa, 700hPa, 600hPa, 500hPa, 400hPa, 300hPa, 250hPa, 200hPa, 150hPa, 100hPa, 50hPa in this exact order.

**Spatial & Coordinate Details:**

- Latitude dimension (721 points) ranges from 90°N to 90°S with a 0.25° spacing.
- Longitude dimension (1440 points) ranges from 0° to 359.75°E with a 0.25° spacing.
- Data should be single precision floats (`.astype(np.float32)`).

**Supported Data Sources:**

- ERA5 initial fields (strongly recommended).
- ECMWF initial fields (e.g., HRES forecast) can be used, but may result in a slight accuracy drop.
- Other types of initial fields are not currently supported due to potentially large discrepancies in data fields.

**Converting Your Data:**

- ERA5 `.nc` files can be converted to `.npy` using the `netCDF4` Python package.
- ECMWF `.grib` files can be converted to `.npy` using the `pygrib` Python package.
- Ensure the order of variables and pressure levels is exactly as described above.
""")

    # File uploaders for surface and upper data separately.
    st.markdown("### Upload Your Input Data Files")
    input_surface_file = st.file_uploader(
        "Upload input_surface.npy",
        type=["npy"],
        key="pangu_input_surface",
    )
    input_upper_file = st.file_uploader(
        "Upload input_upper.npy",
        type=["npy"],
        key="pangu_input_upper",
    )

    st.markdown("---")
    st.markdown("### References & Resources")
    st.markdown("""
- **Research Paper:** [Accurate medium-range global weather forecasting with 3D neural networks](https://www.nature.com/articles/s41586-023-06185-3)
- [Pangu-Weather: A 3D High-Resolution Model for Fast and Accurate Global Weather Forecast](https://arxiv.org/abs/2211.02556)
- **GitHub Source Code:** [Pangu-Weather on GitHub](https://github.com/198808xc/Pangu-Weather?tab=readme-ov-file)
""")
    return input_surface_file, input_upper_file
def inference_24hrs(input, input_surface):
    """Run the 24-hour Pangu-Weather ONNX model once on CPU.

    Args:
        input: Upper-air state, described in ``pangu_config_data`` as a
            ``(5, 13, 721, 1440)`` array. Converted to float32 if needed.
        input_surface: Surface state, described as ``(4, 721, 1440)``.
            Converted to float32 if needed.

    Returns:
        tuple: ``(output, output_surface)`` as produced by the model's two outputs.
    """
    # Only an InferenceSession is needed. Parsing the file with onnx.load as well
    # would read the entire model (weights included) into memory for nothing.
    options = ort.SessionOptions()
    # Disable the CPU memory arena, memory-pattern planning and buffer reuse.
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption.
    options.intra_op_num_threads = 1
    ort_session_24 = ort.InferenceSession(
        'Pangu-Weather/pangu_weather_24.onnx',
        sess_options=options,
        providers=['CPUExecutionProvider'],
    )
    # Pangu-Weather inputs are single-precision floats; coerce e.g. float64 uploads.
    feeds = {
        'input': np.asarray(input, dtype=np.float32),
        'input_surface': np.asarray(input_surface, dtype=np.float32),
    }
    output, output_surface = ort_session_24.run(None, feeds)
    return output, output_surface
@st.cache_resource
def inference_6hrs(input, input_surface):
    """Run the 6-hour Pangu-Weather ONNX model once on CPU.

    Results are memoised by Streamlit's ``st.cache_resource``.

    Args:
        input: Upper-air state, described in ``pangu_config_data`` as a
            ``(5, 13, 721, 1440)`` array. Converted to float32 if needed.
        input_surface: Surface state, described as ``(4, 721, 1440)``.
            Converted to float32 if needed.

    Returns:
        tuple: ``(output, output_surface)`` as produced by the model's two outputs.
    """
    # Only an InferenceSession is needed. Parsing the file with onnx.load as well
    # would read the entire model (weights included) into memory for nothing.
    options = ort.SessionOptions()
    # Disable the CPU memory arena, memory-pattern planning and buffer reuse.
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption.
    options.intra_op_num_threads = 1
    ort_session_6 = ort.InferenceSession(
        'Pangu-Weather/pangu_weather_6.onnx',
        sess_options=options,
        providers=['CPUExecutionProvider'],
    )
    # Pangu-Weather inputs are single-precision floats; coerce e.g. float64 uploads.
    feeds = {
        'input': np.asarray(input, dtype=np.float32),
        'input_surface': np.asarray(input_surface, dtype=np.float32),
    }
    output, output_surface = ort_session_6.run(None, feeds)
    return output, output_surface
@st.cache_resource
def inference_1hr(input, input_surface):
    """Run the 1-hour Pangu-Weather ONNX model once on CPU.

    Results are memoised by Streamlit's ``st.cache_resource``.

    Args:
        input: Upper-air state, described in ``pangu_config_data`` as a
            ``(5, 13, 721, 1440)`` array. Converted to float32 if needed.
        input_surface: Surface state, described as ``(4, 721, 1440)``.
            Converted to float32 if needed.

    Returns:
        tuple: ``(output, output_surface)`` as produced by the model's two outputs.
    """
    # Only an InferenceSession is needed. Parsing the file with onnx.load as well
    # would read the entire model (weights included) into memory for nothing.
    options = ort.SessionOptions()
    # Disable the CPU memory arena, memory-pattern planning and buffer reuse.
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption.
    options.intra_op_num_threads = 1
    ort_session_1 = ort.InferenceSession(
        'Pangu-Weather/pangu_weather_1.onnx',
        sess_options=options,
        providers=['CPUExecutionProvider'],
    )
    # Pangu-Weather inputs are single-precision floats; coerce e.g. float64 uploads.
    feeds = {
        'input': np.asarray(input, dtype=np.float32),
        'input_surface': np.asarray(input_surface, dtype=np.float32),
    }
    output, output_surface = ort_session_1.run(None, feeds)
    return output, output_surface
@st.cache_resource
def inference_3hrs(input, input_surface):
    """Run the 3-hour Pangu-Weather ONNX model once on CPU.

    Results are memoised by Streamlit's ``st.cache_resource``.

    Args:
        input: Upper-air state, described in ``pangu_config_data`` as a
            ``(5, 13, 721, 1440)`` array. Converted to float32 if needed.
        input_surface: Surface state, described as ``(4, 721, 1440)``.
            Converted to float32 if needed.

    Returns:
        tuple: ``(output, output_surface)`` as produced by the model's two outputs.
    """
    # Only an InferenceSession is needed. Parsing the file with onnx.load as well
    # would read the entire model (weights included) into memory for nothing.
    options = ort.SessionOptions()
    # Disable the CPU memory arena, memory-pattern planning and buffer reuse.
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    # Increase the number for faster inference and more memory consumption.
    options.intra_op_num_threads = 1
    ort_session_3 = ort.InferenceSession(
        'Pangu-Weather/pangu_weather_3.onnx',
        sess_options=options,
        providers=['CPUExecutionProvider'],
    )
    # Pangu-Weather inputs are single-precision floats; coerce e.g. float64 uploads.
    feeds = {
        'input': np.asarray(input, dtype=np.float32),
        'input_surface': np.asarray(input_surface, dtype=np.float32),
    }
    output, output_surface = ort_session_3.run(None, feeds)
    return output, output_surface
@st.cache_resource
def inference_custom_hrs(input, input_surface, forecast_hours):
    """Forecast ``forecast_hours`` ahead by chaining the 24-hour model on CPU.

    Each step's prediction is fed back as the next step's input. Results are
    memoised by Streamlit's ``st.cache_resource``.

    Args:
        input: Upper-air state, described in ``pangu_config_data`` as a
            ``(5, 13, 721, 1440)`` array. Converted to float32 if needed.
        input_surface: Surface state, described as ``(4, 721, 1440)``.
            Converted to float32 if needed.
        forecast_hours: Lead time in hours; must be a positive multiple of 24.

    Returns:
        tuple: ``(upper, surface)`` after the last 24-hour step.

    Raises:
        ValueError: if ``forecast_hours`` is not positive or not a multiple of 24.
    """
    # Without this, e.g. -24 passes the modulo check, runs zero steps and the
    # unchanged input would be returned as if it were a forecast.
    if forecast_hours <= 0:
        raise ValueError("forecast_hours must be positive.")
    if forecast_hours % 24 != 0:
        raise ValueError("forecast_hours must be a multiple of 24.")
    # int() so integral floats such as 48.0 work with range().
    steps = int(forecast_hours) // 24

    # Only an InferenceSession is needed. Parsing the file with onnx.load as well
    # would read the entire model (weights included) into memory for nothing.
    options = ort.SessionOptions()
    # Disable the CPU memory arena, memory-pattern planning and buffer reuse.
    options.enable_cpu_mem_arena = False
    options.enable_mem_pattern = False
    options.enable_mem_reuse = False
    options.intra_op_num_threads = 1
    # The session is built once and reused for every step.
    ort_session_24 = ort.InferenceSession(
        'Pangu-Weather/pangu_weather_24.onnx',
        sess_options=options,
        providers=['CPUExecutionProvider'],
    )

    # Pangu-Weather inputs are single-precision floats; coerce e.g. float64 uploads.
    upper = np.asarray(input, dtype=np.float32)
    surface = np.asarray(input_surface, dtype=np.float32)
    for _ in range(steps):
        upper, surface = ort_session_24.run(None, {'input': upper, 'input_surface': surface})
    return upper, surface
def _plot_global_field(field, title):
    """Draw one 2-D lat/lon field on a global PlateCarree map and render it in Streamlit.

    ``field`` is assumed to follow the grid described in ``pangu_config_data``:
    rows from 90°N down to 90°S, columns from 0°E eastward, equally spaced
    around the whole globe.
    """
    # Shift the columns by half a revolution so they start at 180°W (-180°)
    # instead of 0°E. This matches PlateCarree()'s native -180..180 longitude
    # range, so no part of the field falls outside the map.
    centered = np.roll(field, field.shape[-1] // 2, axis=-1)
    fig, ax = plt.subplots(figsize=(10, 5), subplot_kw={'projection': ccrs.PlateCarree()})
    try:
        ax.set_title(title)
        # Row 0 is the northernmost latitude, so it must be drawn at the top
        # (origin='upper'); origin='lower' would flip the map upside-down.
        image = ax.imshow(
            centered,
            extent=[-180, 180, -90, 90],
            origin='upper',
            cmap='coolwarm',
            transform=ccrs.PlateCarree(),
        )
        ax.coastlines()
        plt.colorbar(image, ax=ax, orientation='horizontal', pad=0.05)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure open until it is closed explicitly; close it
        # once rendered so Streamlit reruns do not accumulate figures.
        plt.close(fig)


def _show_field_section(label, key_prefix, upper, surface,
                        upper_vars, upper_hpa_values, surface_vars):
    """Render the source/variable/level pickers for one dataset and plot the chosen field.

    ``label`` ("Initial"/"Predicted") prefixes the subheader and plot title;
    ``key_prefix`` ("init"/"pred") keeps the widget keys unique per section.
    """
    st.subheader(f"{label} Data Visualization")
    source_col, var_col = st.columns([1, 1])
    with source_col:
        data_choice = st.selectbox("Data Source", ["Upper-Air Data", "Surface Data"],
                                   key=f"{key_prefix}_data_choice")
    is_upper = data_choice == "Upper-Air Data"
    with var_col:
        if is_upper:
            var = st.selectbox("Variable", upper_vars, key=f"{key_prefix}_upper_var")
        else:
            var = st.selectbox("Variable", surface_vars, key=f"{key_prefix}_surface_var")

    if is_upper:
        level_hpa = st.select_slider(
            "Select Pressure Level (hPa)",
            options=upper_hpa_values,
            value=850,  # Default to 850hPa
            help="Select the pressure level in hPa.",
            key=f"{key_prefix}_level_hpa_slider",
        )
        # Upper-air arrays are indexed [variable, pressure level, lat, lon].
        field = upper[upper_vars.index(var), upper_hpa_values.index(level_hpa), :, :]
        title = f"{label} Upper-Air: {var} at {level_hpa}hPa"
    else:
        field = surface[surface_vars.index(var), :, :]
        title = f"{label} Surface: {var}"
    _plot_global_field(field, title)


def _npy_bytes(array):
    """Serialise ``array`` in NumPy ``.npy`` format and return the raw bytes."""
    buffer = io.BytesIO()
    np.save(buffer, array)
    return buffer.getvalue()


def plot_pangu_output(upper_data, surface_data, out_upper, out_surface):
    """Visualise initial and predicted Pangu-Weather fields and offer the predictions for download.

    Args:
        upper_data: Initial upper-air array, indexed [variable, level, lat, lon].
        surface_data: Initial surface array, indexed [variable, lat, lon].
        out_upper: Predicted upper-air array, indexed like ``upper_data``.
        out_surface: Predicted surface array, indexed like ``surface_data``.
    """
    # Variable and level names, in the array order given in pangu_config_data.
    upper_vars = ["Z (Geopotential)", "Q (Specific Humidity)", "T (Temperature)",
                  "U (Eastward Wind)", "V (Northward Wind)"]
    upper_levels = ["1000hPa", "925hPa", "850hPa", "700hPa", "600hPa", "500hPa",
                    "400hPa", "300hPa", "250hPa", "200hPa", "150hPa", "100hPa", "50hPa"]
    # Numeric hPa values for the level slider.
    upper_hpa_values = [int(level.replace("hPa", "")) for level in upper_levels]
    surface_vars = ["MSLP", "U10", "V10", "T2M"]

    _show_field_section("Initial", "init", upper_data, surface_data,
                        upper_vars, upper_hpa_values, surface_vars)
    _show_field_section("Predicted", "pred", out_upper, out_surface,
                        upper_vars, upper_hpa_values, surface_vars)

    # --- Download Buttons ---
    st.subheader("Download Predicted Data")
    st.download_button(label="Download Predicted Upper-Air Data",
                       data=_npy_bytes(out_upper),
                       file_name="predicted_upper.npy",
                       mime="application/octet-stream")
    st.download_button(label="Download Predicted Surface Data",
                       data=_npy_bytes(out_surface),
                       file_name="predicted_surface.npy",
                       mime="application/octet-stream")