# Captured from a Hugging Face Spaces page (Space status shown as: Sleeping).
import logging
import os
import random
import tempfile
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import yaml

from data_utils import (
    save_uploaded_files,
    load_dataset,
)
from inference_utils import run_inference
from config_utils import load_config
from plot_utils import plot_prithvi_output, plot_aurora_output
from prithvi_utils import (
    prithvi_config_ui,
    initialize_prithvi_model,
    prepare_prithvi_batch,
)
from aurora_utils import aurora_config_ui, prepare_aurora_batch, initialize_aurora_model
from pangu_utils import (
    pangu_config_data,
    inference_1hr,
    inference_3hrs,
    inference_6hrs,
    inference_24hrs,
    inference_custom_hrs,
    plot_pangu_output,
)
from fengwu_utils import (
    fengwu_config_data,
    inference_6hrs_fengwu,
    inference_12hrs_fengwu,
    inference_custom_hrs_fengwu,
    plot_fengwu_output,
)
# Root logging at INFO, plus a module-level logger for this script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streamlit page settings. Keep this call ahead of every other `st.*` call:
# older Streamlit releases reject set_page_config once other commands have run.
_PAGE_CONFIG = {
    "page_title": "Weather Data Processor",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# --- Header: app title (wide column) and model picker (narrow column) ---
header_col1, header_col2 = st.columns([4, 1])
with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")
with header_col2:
    st.markdown("### Select a Model")
    # The markdown heading above is the visible caption. The widget still gets a
    # real label (Streamlit warns that empty labels are discouraged for
    # accessibility), but it is collapsed so the caption is not shown twice.
    selected_model = st.selectbox(
        "Select a Model",
        options=["Pangu-Weather", "FengWu", "Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        help="Select the model you want to use.",
        label_visibility="collapsed",
    )
st.write("---")
# --- Helpers for the inference / visualization flow ---


def _setup_device():
    """Select the torch device, seed the RNGs, and report the device in the UI.

    Returns the chosen ``torch.device``. On any initialisation error the
    traceback is shown in the app and the script run is stopped.
    """
    try:
        torch.jit.enable_onednn_fusion(True)
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.write(f"Using device: **{torch.cuda.get_device_name()}**")
            # deterministic=True requests reproducible cuDNN kernels; benchmark
            # mode autotunes algorithm choice per run and can defeat that (see
            # the PyTorch reproducibility notes), so it is left off.
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        else:
            device = torch.device("cpu")
            st.write("Using device: **CPU**")
        # Fixed seeds for the Python, torch and NumPy RNGs.
        random.seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
        torch.manual_seed(42)
        np.random.seed(42)
    except Exception:
        logger.exception("Error initializing device")
        st.error("Error initializing device:")
        st.error(traceback.format_exc())
        st.stop()
    return device


def _run_prithvi(config, config_path, weights_path, surface_files, vertical_files,
                 clim_surf_path, clim_vert_path, device):
    """Initialise Prithvi, build its input batch, run inference, store the output.

    Results go to ``st.session_state['prithvi_out']`` and
    ``st.session_state['prithvi_done']``.
    """
    # Only the model is used here; the other returned values are not needed.
    model, _in_mu, _in_sig, _output_sig, _static_mu, _static_sig = initialize_prithvi_model(
        config, config_path, weights_path, device
    )
    batch = prepare_prithvi_batch(
        surface_files, vertical_files, clim_surf_path, clim_vert_path, device
    )
    out = run_inference("Prithvi", model, batch, device)
    st.session_state['prithvi_out'] = out
    st.session_state['prithvi_done'] = True


def _run_aurora(files, device):
    """Save the uploaded files, load them as a dataset, and run Aurora on it.

    Stops the script run if nothing was uploaded or the dataset fails to load.
    Results go to ``aurora_out`` / ``aurora_ds_subset`` / ``aurora_done`` in
    ``st.session_state``.
    """
    if not files:
        st.error("Please upload data files for Aurora.")
        st.stop()
    save_uploaded_files(files)
    # NOTE(review): temp_file_paths is presumably populated by
    # save_uploaded_files; confirm in data_utils.
    ds = load_dataset(st.session_state.temp_file_paths)
    if ds is None:
        st.error("Failed to load dataset for Aurora.")
        st.stop()
    batch = prepare_aurora_batch(ds)
    model = initialize_aurora_model(device)
    out = run_inference("Aurora", model, batch, device)
    st.session_state['aurora_out'] = out
    st.session_state['aurora_ds_subset'] = ds
    st.session_state['aurora_done'] = True


def _run_fengwu(file1, file2, duration, custom_hours):
    """Run FengWu on the two uploaded ``.npy``-style inputs and store the forecast.

    ``duration`` is one of the forecast-duration labels; ``custom_hours`` is the
    value entered for "Custom" (``None`` for every other label). Stops the
    script run on missing uploads, unsupported durations, or inference errors,
    so stale results from an earlier run are never shown as if they were new.
    """
    if not (file1 and file2):
        st.error("Please upload data files for FengWu.")
        st.stop()
    not_available = {
        "1 hour": "1hr inference is not yet available on this model.",
        "3 hours": "3hrs inference is not yet available on this model.",
    }
    if duration in not_available:
        # There is no output to store for these durations; stop instead of
        # falling through to the session-state writes below.
        st.warning(not_available[duration])
        st.stop()
    try:
        input1 = np.load(file1)
        input2 = np.load(file2)
        if duration == "6 hours":
            output_fengwu = inference_6hrs_fengwu(input1, input2)
        elif duration == "12 hours":
            # Not offered by forecast_options today; kept wired for when it is.
            output_fengwu = inference_12hrs_fengwu(input1, input2)
        else:
            # "24 hours" has no dedicated FengWu entry point. Route it through
            # the custom path with an explicit 24 (the custom input's minimum),
            # because custom_hours is None unless "Custom" was chosen.
            hours = 24 if duration == "24 hours" else custom_hours
            output_fengwu = inference_custom_hrs_fengwu(input1, input2, hours)
    except Exception as e:
        logger.exception("FengWu inference failed")
        st.error(f"An error occurred: {e}")
        st.stop()
    st.session_state['output_fengwu'] = output_fengwu
    st.session_state['fengwu_done'] = True
    # NOTE(review): this stores the uploaded file object for the second input,
    # not the loaded array; confirm plot_fengwu_output expects that.
    st.session_state['input_fengwu'] = file2


def _run_pangu(surface_file, upper_file, duration, custom_hours):
    """Run Pangu-Weather on the uploaded surface/upper inputs and store results.

    Fixed durations dispatch to their dedicated inference function; anything
    else ("Custom") goes through ``inference_custom_hrs`` with ``custom_hours``.
    Stops the script run on missing uploads or inference errors.
    """
    if not (surface_file and upper_file):
        st.error("Please upload data files for Pangu-Weather.")
        st.stop()
    fixed_runs = {
        "1 hour": inference_1hr,
        "3 hours": inference_3hrs,
        "6 hours": inference_6hrs,
        "24 hours": inference_24hrs,
    }
    try:
        surface_data = np.load(surface_file)
        upper_data = np.load(upper_file)
        run = fixed_runs.get(duration)
        if run is not None:
            out_upper, out_surface = run(upper_data, surface_data)
        else:
            out_upper, out_surface = inference_custom_hrs(upper_data, surface_data, custom_hours)
        st.session_state['pangu_upper_data'] = upper_data
        st.session_state['pangu_surface_data'] = surface_data
        st.session_state['pangu_out_upper'] = out_upper
        st.session_state['pangu_out_surface'] = out_surface
        st.session_state['pangu_done'] = True
        st.write("**Forecast Results:**")
        st.write("Upper Data Forecast Shape:", out_upper.shape)
        st.write("Surface Data Forecast Shape:", out_surface.shape)
    except Exception as e:
        logger.exception("Pangu-Weather inference failed")
        st.error(f"An error occurred: {e}")
        st.stop()


def _render_results(model_name):
    """Plot stored results for *model_name* from ``st.session_state``.

    Returns True if a completed run was found and handed to the model's plot
    helper, False otherwise.
    """
    state = st.session_state
    if model_name == "Prithvi" and state.get('prithvi_done'):
        plot_prithvi_output(state['prithvi_out'])
    elif model_name == "Aurora" and state.get('aurora_done'):
        plot_aurora_output(state['aurora_out'], state['aurora_ds_subset'])
    elif model_name == "Pangu-Weather" and state.get('pangu_done'):
        plot_pangu_output(
            state['pangu_upper_data'],
            state['pangu_surface_data'],
            state['pangu_out_upper'],
            state['pangu_out_surface'],
        )
    elif model_name == "FengWu" and state.get('fengwu_done'):
        plot_fengwu_output(state['input_fengwu'], state['output_fengwu'])
    else:
        return False
    return True


# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")
    # Model-specific configuration UI. Each branch binds only the names its
    # model needs; the inference code below reads them only for that model.
    if selected_model == "Prithvi":
        (config, uploaded_surface_files, uploaded_vertical_files,
         clim_surf_path, clim_vert_path, config_path, weights_path) = prithvi_config_ui()
    elif selected_model == "Aurora":
        uploaded_files = aurora_config_ui()
    elif selected_model == "Pangu-Weather":
        input_surface_file, input_upper_file = pangu_config_data()
    elif selected_model == "FengWu":
        input_file1_fengwu, input_file2_fengwu = fengwu_config_data()
    else:
        # Generic data upload for other models
        st.subheader(f"{selected_model} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        uploaded_files = st.file_uploader(
            f"Upload Data Files for {selected_model}",
            accept_multiple_files=True,
            key=f"{selected_model.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )
    st.write("---")

    # --- Forecast Duration Selection ---
    st.subheader("Forecast Duration")
    forecast_options = ["1 hour", "3 hours", "6 hours", "24 hours", "Custom"]
    selected_duration = st.selectbox(
        "Select forecast duration",
        forecast_options,
        index=3,  # Default to 24 hours
        help="Select how many hours to forecast.",
    )
    custom_hours = None
    if selected_duration == "Custom":
        custom_hours = st.number_input(
            "Enter custom forecast hours",
            min_value=24,
            max_value=480,
            value=48,
            step=24,
            help="Enter the number of hours you want to forecast.",
        )
    st.write("---")

    # Run Inference button, rendered at the bottom of the configuration column.
    run_clicked = st.button("🚀 Run Inference")

if run_clicked:
    with right_col:
        st.header("📊 Inference Progress & Visualization")
        device = _setup_device()
        with st.spinner("Running inference, please wait..."):
            if selected_model == "Prithvi":
                _run_prithvi(
                    config, config_path, weights_path,
                    uploaded_surface_files, uploaded_vertical_files,
                    clim_surf_path, clim_vert_path, device,
                )
            elif selected_model == "Aurora":
                _run_aurora(uploaded_files, device)
            elif selected_model == "FengWu":
                _run_fengwu(input_file1_fengwu, input_file2_fengwu,
                            selected_duration, custom_hours)
            elif selected_model == "Pangu-Weather":
                _run_pangu(input_surface_file, input_upper_file,
                           selected_duration, custom_hours)
            else:
                st.warning("Inference not implemented for this model.")
                st.stop()
        # Visualization after inference is done
        if not _render_results(selected_model):
            st.info("No visualization implemented for this model.")
else:
    # Not running inference on this rerun: re-display earlier results, if any.
    with right_col:
        st.header("🖥️ Visualization & Progress")
        if not _render_results(selected_model):
            st.info("Awaiting inference to display results.")