import os
import random
import tempfile
import traceback
import logging
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import yaml
import xarray as xr
import streamlit as st
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import cartopy.crs as ccrs

from Prithvi import *  # Ensure this provides PrithviWxC, preproc, and the scaler/dataset helpers
from aurora import Aurora, Batch, Metadata, rollout

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Save uploaded files to temporary files and store their paths in session_state
def save_uploaded_files(uploaded_files):
    if 'temp_file_paths' not in st.session_state:
        st.session_state.temp_file_paths = []
    for uploaded_file in uploaded_files:
        suffix = os.path.splitext(uploaded_file.name)[1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        temp_file.write(uploaded_file.read())
        temp_file.close()
        st.session_state.temp_file_paths.append(temp_file.name)

# Load a dataset from one or more NetCDF files.
# Note: despite the original "cached" label, no Streamlit cache decorator is
# applied here, so the files are re-read on every rerun.
def load_dataset(file_paths):
    try:
        ds = xr.open_mfdataset(file_paths, combine='by_coords').load()
        return ds
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None
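
# A cached variant is possible with Streamlit's resource cache; a minimal
# sketch (not wired in, and load_dataset_cached is a hypothetical name):
#
# @st.cache_resource(show_spinner=False)
# def load_dataset_cached(file_paths: tuple):
#     # The argument must be hashable for caching, hence a tuple of paths
#     return xr.open_mfdataset(list(file_paths), combine='by_coords').load()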

# Set page configuration
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Header with two columns: one for the title, one for the model selector
header_col1, header_col2 = st.columns([4, 1])  # Adjust the ratio as needed

with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        help="Select the model you want to use for processing the data.",
    )

st.write("---")  # Horizontal separator

# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])  # Adjust column ratios as needed

with left_col:
    st.header("🔧 Configuration")

    # --- Dynamic Configuration Based on Selected Model ---
    def get_model_configuration(model_name):
        if model_name == "Prithvi":
            st.subheader("Prithvi Model Configuration")
            # Prithvi-specific configuration inputs
            param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
            param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
            # Add other Prithvi-specific parameters here
            config = {
                "param1": param1,
                "param2": param2,
                # Include other parameters as needed
            }

            # --- Prithvi-Specific File Uploads ---
            st.markdown("### Upload Data Files for Prithvi Model")
            # File uploader for surface data
            uploaded_surface_files = st.file_uploader(
                "Upload Surface Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="surface_uploader",
            )
            # File uploader for vertical data
            uploaded_vertical_files = st.file_uploader(
                "Upload Vertical Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="vertical_uploader",
            )

            # Handle climatology files
            st.markdown("### Upload Climatology Files (If Missing)")
            # Default climatology file paths
            default_clim_dir = Path("Prithvi-WxC/examples/climatology")
            surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
            vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
            surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
            vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

            # Check whether all climatology files exist
            clim_files_exist = all(
                [
                    surf_in_scal_path.exists(),
                    vert_in_scal_path.exists(),
                    surf_out_scal_path.exists(),
                    vert_out_scal_path.exists(),
                ]
            )

            if not clim_files_exist:
                st.warning("Climatology files are missing.")
                uploaded_clim_surface = st.file_uploader(
                    "Upload Climatology Surface File",
                    type=["nc", "netcdf"],
                    key="clim_surface_uploader",
                )
                uploaded_clim_vertical = st.file_uploader(
                    "Upload Climatology Vertical File",
                    type=["nc", "netcdf"],
                    key="clim_vertical_uploader",
                )
                # Process uploaded climatology files
                if uploaded_clim_surface and uploaded_clim_vertical:
                    clim_temp_dir = tempfile.mkdtemp()
                    clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
                    with open(clim_surf_path, "wb") as f:
                        f.write(uploaded_clim_surface.getbuffer())
                    clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
                    with open(clim_vert_path, "wb") as f:
                        f.write(uploaded_clim_vertical.getbuffer())
                    st.success("Climatology files uploaded and saved.")
                else:
                    st.warning("Please upload both climatology surface and vertical files.")
            else:
                clim_surf_path = surf_in_scal_path
                clim_vert_path = vert_in_scal_path

            # Optional: upload config.yaml
            uploaded_config = st.file_uploader(
                "Upload config.yaml",
                type=["yaml", "yml"],
                key="config_uploader",
            )
            if uploaded_config:
                # tempfile.mktemp is deprecated and racy; create the file directly
                with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as f:
                    f.write(uploaded_config.getbuffer())
                    config_path = Path(f.name)
                st.success("config.yaml uploaded and saved.")
            else:
                # Use the default config.yaml path
                config_path = Path("Prithvi-WxC/examples/config.yaml")
                if not config_path.exists():
                    st.error("Default config.yaml not found. Please upload a config file.")
                    st.stop()

            # Optional: upload model weights
            uploaded_weights = st.file_uploader(
                "Upload Model Weights (.pt)",
                type=["pt"],
                key="weights_uploader",
            )
            if uploaded_weights:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f:
                    f.write(uploaded_weights.getbuffer())
                    weights_path = Path(f.name)
                st.success("Model weights uploaded and saved.")
            else:
                # Use the default weights path
                weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
                if not weights_path.exists():
                    st.error("Default model weights not found. Please upload model weights.")
                    st.stop()

            return (
                config,
                uploaded_surface_files,
                uploaded_vertical_files,
                clim_surf_path,
                clim_vert_path,
                config_path,
                weights_path,
            )
        else:
            # For other models, provide a simple file uploader
            st.subheader(f"{model_name} Model Data Upload")
            st.markdown("### Drag and Drop Your Data Files Here")
            uploaded_files = st.file_uploader(
                f"Upload Data Files for {model_name}",
                accept_multiple_files=True,
                key=f"{model_name.lower()}_uploader",
                type=["nc", "netcdf", "nc4"],
            )
            return uploaded_files

    # Retrieve model-specific configuration and files
    if selected_model == "Prithvi":
        (
            config,
            uploaded_surface_files,
            uploaded_vertical_files,
            clim_surf_path,
            clim_vert_path,
            config_path,
            weights_path,
        ) = get_model_configuration(selected_model)
    else:
        uploaded_files = get_model_configuration(selected_model)

st.write("---")  # Horizontal separator

# --- Run Inference Button ---
if st.button("🚀 Run Inference"):
    with right_col:
        st.header("📈 Inference Progress & Visualization")

        # Initialize device
        try:
            torch.jit.enable_onednn_fusion(True)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                # Note: benchmark=True and deterministic=True pull in opposite
                # directions; the deterministic flag wins for conv algorithm selection.
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = True
            else:
                device = torch.device("cpu")
                st.write("Using device: **CPU**")
        except Exception:
            st.error("Error initializing device:")
            st.error(traceback.format_exc())
            st.stop()

        # Set random seeds for reproducibility
        try:
            random.seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            torch.manual_seed(42)
            np.random.seed(42)
        except Exception:
            st.error("Error setting random seeds:")
            st.error(traceback.format_exc())
            st.stop()

        # # Define variables and parameters based on dataset type
        # if dataset_type == "MERRA2":
        #     surface_vars = [
        #         "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        #         "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
        #         "TQV", "TS", "U10M", "V10M", "Z0M",
        #     ]
        #     static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
        #     vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
        #     levels = [
        #         34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0,
        #         63.0, 68.0, 71.0, 72.0,
        #     ]
        # elif dataset_type == "GEOS5":
        #     # Define GEOS5-specific variables
        #     surface_vars = [
        #         "GEOS5_EFLUX", "GEOS5_GWETROOT", "GEOS5_HFLUX", "GEOS5_LAI",
        #         "GEOS5_LWGAB", "GEOS5_LWGEM", "GEOS5_LWTUP", "GEOS5_PS",
        #         "GEOS5_QV2M", "GEOS5_SLP", "GEOS5_SWGNT", "GEOS5_SWTNT",
        #         "GEOS5_T2M", "GEOS5_TQI", "GEOS5_TQL", "GEOS5_TQV", "GEOS5_TS",
        #         "GEOS5_U10M", "GEOS5_V10M", "GEOS5_Z0M",
        #     ]
        #     static_surface_vars = ["GEOS5_FRACI", "GEOS5_FRLAND", "GEOS5_FROCEAN", "GEOS5_PHIS"]
        #     vertical_vars = [
        #         "GEOS5_CLOUD", "GEOS5_H", "GEOS5_OMEGA", "GEOS5_PL", "GEOS5_QI",
        #         "GEOS5_QL", "GEOS5_QV", "GEOS5_T", "GEOS5_U", "GEOS5_V",
        #     ]
        #     levels = [
        #         # Define levels specific to GEOS5 if different
        #         10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0,
        #     ]
        # else:
        #     st.error("Unsupported dataset type selected.")
        #     st.stop()

        # Prithvi model options
        padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
        residual = "climate"  # Predict anomalies relative to climatology
        masking_mode = "local"
        decoder_shifting = True
        masking_ratio = 0.99
        positional_encoding = "fourier"
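        # Note: masking_ratio is the fraction of input mask units hidden from
        # the encoder, so 0.99 asks the model to reconstruct the state from
        # roughly 1% of the input (plus climatology, given residual="climate").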

        # --- Initialize Dataset ---
        try:
            with st.spinner("Initializing dataset..."):
                if selected_model == "Prithvi":
                    pass
                    # # Validate climatology files
                    # if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                    #     st.error("Climatology files are missing. Please upload both climatology surface and vertical files.")
                    #     st.stop()
                    # dataset = Merra2Dataset(
                    #     time_range=time_range,
                    #     lead_times=lead_times,
                    #     input_times=input_times,
                    #     data_path_surface=surf_dir,
                    #     data_path_vertical=vert_dir,
                    #     climatology_path_surface=clim_surf_path,
                    #     climatology_path_vertical=clim_vert_path,
                    #     surface_vars=surface_vars,
                    #     static_surface_vars=static_surface_vars,
                    #     vertical_vars=vertical_vars,
                    #     levels=levels,
                    #     positional_encoding=positional_encoding,
                    # )
                    # assert len(dataset) > 0, "There doesn't seem to be any valid data."
                elif selected_model == "Aurora":
                    # TODO: temporary ingestion path; replace with a proper loader
                    if uploaded_files:
                        try:
                            # Save each uploaded file to a temporary file, then
                            # open them all together with xarray
                            save_uploaded_files(uploaded_files)
                            ds = load_dataset(st.session_state.temp_file_paths)
                            if ds is not None:
                                st.success("Files successfully loaded!")
                                st.session_state.ds_subset = ds
                                ds = ds.fillna(ds.mean())
                                desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
                                # Ensure that the 'lev' (pressure level) dimension exists
                                if 'lev' not in ds.dims:
                                    raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

                                def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
                                    """Stack time steps i-6 and i into a batched tensor."""
                                    # Select the step six indices back and the current step
                                    selected = x[[i - 6, i]]
                                    # Add a batch dimension
                                    selected = selected[None]
                                    # Ensure the data is contiguous in memory
                                    selected = selected.copy()
                                    # Convert to a PyTorch tensor
                                    return torch.from_numpy(selected)
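                                # Illustrative: if x has shape (T, H, W), then
                                # _prepare(x, 6) stacks steps 0 and 6 into shape
                                # (1, 2, H, W); for level data of shape (T, C, H, W)
                                # the result is (1, 2, C, H, W) -- the (history,
                                # current) pair Aurora expects per variable.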
                                # Flip latitudes and shift longitudes to [0, 360)
                                lat = ds.lat.values * -1
                                lon = ds.lon.values + 180
                                # Subset the dataset to only the desired pressure levels
                                ds_subset = ds.sel(lev=desired_levels, method="nearest")
                                # Verify that all desired levels are present
                                present_levels = ds_subset.lev.values
                                missing_levels = set(desired_levels) - set(present_levels)
                                if missing_levels:
                                    raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")
                                # Extract pressure levels (hPa) after subsetting
                                lev = ds_subset.lev.values

                                # Prepare surface variables at 1000 hPa
                                try:
                                    lev_index_1000 = np.where(lev == 1000)[0][0]
                                except IndexError:
                                    raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
                                T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
                                U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
                                V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
                                SLP = ds_subset.SLP.compute()
                                # Static variable: select the first time index to drop the time dimension
                                PHIS = ds_subset.PHIS.isel(time=0).compute()

                                # Atmospheric variables on the desired levels, excluding 1000 hPa
                                atmos_levels = [int(level) for level in lev if level != 1000]
                                T_atm = ds_subset.T.sel(lev=atmos_levels).compute()
                                U_atm = ds_subset.U.sel(lev=atmos_levels).compute()
                                V_atm = ds_subset.V.sel(lev=atmos_levels).compute()

                                # Select the time index; _prepare reads step i-6, so i must be >= 6
                                num_times = ds_subset.time.size
                                i = 6  # Adjust as needed (6 <= i < num_times)
                                if i < 6 or i >= num_times:
                                    raise IndexError("Time index i is out of bounds.")
                                time_values = ds_subset.time.values
                                current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)

                                # Surface variables (1000 hPa fields used here as
                                # proxies for the near-surface quantities Aurora expects)
                                surf_vars = {
                                    "2t": _prepare(T_surface.values, i),   # Two-metre temperature
                                    "10u": _prepare(U_surface.values, i),  # Ten-metre eastward wind
                                    "10v": _prepare(V_surface.values, i),  # Ten-metre northward wind
                                    "msl": _prepare(SLP.values, i),        # Mean sea-level pressure
                                }
                                # Static variables (2D tensors)
                                static_vars = {
                                    "z": torch.from_numpy(PHIS.values.copy()),  # Surface geopotential (h, w)
                                    # Add 'lsm' and 'slt' if available and needed
                                }
                                # Atmospheric variables on the selected pressure levels
                                atmos_vars = {
                                    "t": _prepare(T_atm.values, i),  # Temperature
                                    "u": _prepare(U_atm.values, i),  # Eastward wind
                                    "v": _prepare(V_atm.values, i),  # Northward wind
                                }
                                # Define metadata
                                metadata = Metadata(
                                    lat=torch.from_numpy(lat.copy()),
                                    lon=torch.from_numpy(lon.copy()),
                                    time=(current_time,),
                                    atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
                                )
                                # Create the Batch object and keep it for the later steps
                                batch = Batch(
                                    surf_vars=surf_vars,
                                    static_vars=static_vars,
                                    atmos_vars=atmos_vars,
                                    metadata=metadata,
                                )
                                st.session_state['batch'] = batch
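                                # For reference (Aurora's documented conventions):
                                # surf_vars entries are (b, t, h, w), static_vars
                                # are (h, w), and atmos_vars are (b, t, c, h, w),
                                # where c is the number of pressure levels.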
                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                        # finally:
                        #     # Clean up: remove the temporary files
                        #     for path in st.session_state.temp_file_paths:
                        #         try:
                        #             os.remove(path)
                        #         except Exception as e:
                        #             st.warning(f"Could not delete temp file {path}: {e}")
                else:
                    # For other models, implement their specific dataset initialization
                    # Placeholder: replace with actual dataset initialization
                    dataset = None
                    st.warning("Dataset initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Dataset initialized successfully.")
        except Exception:
            st.error("Error initializing dataset:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Load Scalers ---
        try:
            with st.spinner("Loading scalers..."):
                if selected_model == "Prithvi":
                    # NOTE: the scaler loading below is currently disabled, so
                    # in_mu, in_sig, output_sig, static_mu and static_sig are
                    # undefined for Prithvi; re-enable it before initializing
                    # the model.
                    pass
                    # # Assuming the scaler paths are the same as the climatology paths
                    # surf_in_scal_path = clim_surf_path
                    # vert_in_scal_path = clim_vert_path
                    # surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
                    # vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"
                    # # Check that the output scaler files exist
                    # if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                    #     st.error("Anomaly variance scaler files are missing.")
                    #     st.stop()
                    # in_mu, in_sig = input_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_in_scal_path,
                    #     vert_in_scal_path,
                    # )
                    # output_sig = output_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_out_scal_path,
                    #     vert_out_scal_path,
                    # )
                    # static_mu, static_sig = static_input_scalers(
                    #     surf_in_scal_path,
                    #     static_surface_vars,
                    # )
                else:
                    # Placeholder: replace with actual scaler loading for other models
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
            st.success("Scalers loaded successfully.")
        except Exception:
            st.error("Error loading scalers:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Load Configuration ---
        try:
            with st.spinner("Loading configuration..."):
                if selected_model == "Prithvi":
                    with open(config_path, "r") as f:
                        config = yaml.safe_load(f)
                    # Validate the config
                    required_params = [
                        "in_channels", "input_size_time", "in_channels_static",
                        "input_scalers_epsilon", "static_input_scalers_epsilon",
                        "n_lats_px", "n_lons_px", "patch_size_px",
                        "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                        "n_blocks_decoder", "mlp_multiplier", "n_heads",
                        "dropout", "drop_path", "parameter_dropout",
                    ]
                    missing_params = [param for param in required_params if param not in config.get("params", {})]
                    if missing_params:
                        st.error(f"Missing configuration parameters: {missing_params}")
                        st.stop()
                else:
                    # Placeholder: replace with actual configuration loading for other models
                    config = {}
            st.success("Configuration loaded successfully.")
        except Exception:
            st.error("Error loading configuration:")
            st.error(traceback.format_exc())
            st.stop()
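
        # For reference, config.yaml is expected to carry a "params" block
        # shaped roughly like the sketch below (values are illustrative, not
        # the released Prithvi-WxC settings):
        #
        # params:
        #   in_channels: 160
        #   input_size_time: 2
        #   in_channels_static: 8
        #   input_scalers_epsilon: 0.0
        #   static_input_scalers_epsilon: 0.0
        #   n_lats_px: 360
        #   n_lons_px: 576
        #   patch_size_px: [2, 2]
        #   mask_unit_size_px: [30, 32]
        #   embed_dim: 1024
        #   n_blocks_encoder: 12
        #   n_blocks_decoder: 2
        #   mlp_multiplier: 4
        #   n_heads: 16
        #   dropout: 0.0
        #   drop_path: 0.0
        #   parameter_dropout: 0.0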

        # --- Initialize the Model ---
        try:
            with st.spinner("Initializing model..."):
                if selected_model == "Prithvi":
                    model = PrithviWxC(
                        in_channels=config["params"]["in_channels"],
                        input_size_time=config["params"]["input_size_time"],
                        in_channels_static=config["params"]["in_channels_static"],
                        input_scalers_mu=in_mu,
                        input_scalers_sigma=in_sig,
                        input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
                        static_input_scalers_mu=static_mu,
                        static_input_scalers_sigma=static_sig,
                        static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
                        output_scalers=output_sig**0.5,
                        n_lats_px=config["params"]["n_lats_px"],
                        n_lons_px=config["params"]["n_lons_px"],
                        patch_size_px=config["params"]["patch_size_px"],
                        mask_unit_size_px=config["params"]["mask_unit_size_px"],
                        mask_ratio_inputs=masking_ratio,
                        embed_dim=config["params"]["embed_dim"],
                        n_blocks_encoder=config["params"]["n_blocks_encoder"],
                        n_blocks_decoder=config["params"]["n_blocks_decoder"],
                        mlp_multiplier=config["params"]["mlp_multiplier"],
                        n_heads=config["params"]["n_heads"],
                        dropout=config["params"]["dropout"],
                        drop_path=config["params"]["drop_path"],
                        parameter_dropout=config["params"]["parameter_dropout"],
                        residual=residual,
                        masking_mode=masking_mode,
                        decoder_shifting=decoder_shifting,
                        positional_encoding=positional_encoding,
                        checkpoint_encoder=[],
                        checkpoint_decoder=[],
                    )
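                    # output_sig comes from the anomaly-variance files, so the
                    # **0.5 above converts the variance to a standard deviation
                    # before it is handed to the model.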
                elif selected_model == "Aurora":
                    # The Aurora model is constructed at inference time below
                    pass
                else:
                    # Placeholder: replace with actual model initialization for other models
                    model = None
                    st.warning("Model initialization for this model is not implemented yet.")
                    st.stop()
            # model.to(device)  # Device placement happens after weights are loaded
            st.success("Model initialized successfully.")
        except Exception:
            st.error("Error initializing model:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Load Model Weights ---
        try:
            with st.spinner("Loading model weights..."):
                if selected_model == "Prithvi":
                    state_dict = torch.load(weights_path, map_location=device)
                    if "model_state" in state_dict:
                        state_dict = state_dict["model_state"]
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                else:
                    # Placeholder: replace with actual weight loading for other models
                    pass
            st.success("Model weights loaded successfully.")
        except Exception:
            st.error("Error loading model weights:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Prepare Data Batch ---
        try:
            with st.spinner("Preparing data batch..."):
                if selected_model == "Prithvi":
                    data = next(iter(dataset))
                    batch = preproc([data], padding)
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.to(device)
                elif selected_model == "Aurora":
                    # Regrid the batch built above to the 0.25-degree grid the
                    # pretrained Aurora checkpoint expects
                    batch = st.session_state['batch'].regrid(res=0.25)
                else:
                    # Placeholder: replace with actual data preparation for other models
                    batch = None
            st.success("Data batch prepared successfully.")
        except Exception:
            st.error("Error preparing data batch:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Run Inference ---
        try:
            with st.spinner("Running model inference..."):
                if selected_model == "Prithvi":
                    model.eval()
                    with torch.no_grad():
                        out = model(batch)
                elif selected_model == "Aurora":
                    model = Aurora(use_lora=False)
                    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
                    model.eval()
                    # model = model.to("cuda")  # Uncomment if using a GPU
                    with torch.inference_mode():
                        out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
                    model = model.to("cpu")
                    st.session_state.model = model
                else:
                    # Placeholder: replace with actual inference for other models
                    out = torch.randn(1, 10, 180, 360)  # Dummy tensor
            st.success("Model inference completed successfully.")
            st.session_state['out'] = out
        except Exception:
            st.error("Error during model inference:")
            st.error(traceback.format_exc())
            st.stop()
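
        # Note on the Aurora rollout: each autoregressive step of the
        # pretrained 0.25-degree model advances the state by six hours, so
        # steps=2 yields predictions at roughly +6 h and +12 h from the
        # batch time.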

        # --- Visualization Settings ---
        st.markdown("## 📊 Visualization Settings")
        if 'out' in st.session_state and 'batch' in st.session_state and selected_model == "Prithvi":
            # Display the shape of the output tensor
            out_tensor = st.session_state['out']
            st.write(f"**Output tensor shape:** {out_tensor.shape}")
            # Ensure the output tensor has at least 4 dimensions (batch, variables, lat, lon)
            if out_tensor.ndim < 4:
                st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
                st.stop()
            # Number of output variables
            num_variables = out_tensor.shape[1]
            # Define variable names (update with the actual variable names)
            variable_names = [f"Variable_{i}" for i in range(num_variables)]

            # Visualization settings
            col1, col2 = st.columns(2)
            with col1:
                # Select the variable to plot
                selected_variable_name = st.selectbox(
                    "Select Variable to Plot",
                    options=variable_names,
                    index=0,
                    help="Choose the variable you want to visualize.",
                )
                # Select the plot type
                plot_type = st.selectbox(
                    "Select Plot Type",
                    options=["Contour", "Heatmap"],
                    index=0,
                    help="Choose the type of plot to display.",
                )
            with col2:
                # Select the color map
                cmap = st.selectbox(
                    "Select Color Map",
                    options=plt.colormaps(),
                    index=plt.colormaps().index("viridis"),
                    help="Choose the color map for the plot.",
                )
                # Number of levels (contour plots only)
                if plot_type == "Contour":
                    num_levels = st.slider(
                        "Number of Contour Levels",
                        min_value=5,
                        max_value=100,
                        value=20,
                        step=5,
                        help="Set the number of contour levels.",
                    )
                else:
                    num_levels = None

            # Find the index of the selected variable
            variable_index = variable_names.index(selected_variable_name)
            # Extract the selected variable as a 2D array
            selected_variable = out_tensor[0, variable_index].cpu().numpy()
            # Generate latitude and longitude arrays (assumes a regular global grid)
            lat = np.linspace(-90, 90, selected_variable.shape[0])
            lon = np.linspace(-180, 180, selected_variable.shape[1])
            X, Y = np.meshgrid(lon, lat)

            # Plot the selected variable
            st.markdown(f"### Plot of {selected_variable_name}")
            # Matplotlib figure
            fig, ax = plt.subplots(figsize=(10, 6))
            if plot_type == "Contour":
                # Generate the contour plot
                contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
            elif plot_type == "Heatmap":
                # Generate the heatmap
                contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')
            # Add a color bar
            cbar = plt.colorbar(contour, ax=ax)
            cbar.set_label(f'{selected_variable_name}', fontsize=12)
            # Labels and title
            ax.set_xlabel("Longitude", fontsize=12)
            ax.set_ylabel("Latitude", fontsize=12)
            ax.set_title(f"{selected_variable_name}", fontsize=14)
            # Display the plot in Streamlit
            st.pyplot(fig)

            # Optional: interactive Plotly plot
            st.markdown("#### Interactive Plot")
            # Note: Plotly accepts only some of matplotlib's colormap names as
            # colorscales; unsupported names will raise an error here.
            if plot_type == "Contour":
                # ncontours is a top-level Contour attribute, not part of the
                # contours dict (placing it there raises a ValueError)
                fig_plotly = go.Figure(data=go.Contour(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                    contours=dict(
                        coloring='fill',
                        showlabels=True,
                        labelfont=dict(size=12, color='white'),
                    ),
                    ncontours=num_levels,
                ))
            elif plot_type == "Heatmap":
                fig_plotly = go.Figure(data=go.Heatmap(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                ))
            fig_plotly.update_layout(
                xaxis_title="Longitude",
                yaxis_title="Latitude",
                autosize=False,
                width=800,
                height=600,
            )
            st.plotly_chart(fig_plotly)

        elif 'out' in st.session_state and selected_model == "Aurora" and st.session_state['out'] is not None:
            preds = st.session_state['out']
            ds_subset = st.session_state.get('ds_subset', None)
            batch = st.session_state.get('batch', None)

            # Determine the available vertical levels from the prediction
            # tensor, whose layout is (batch, time, level, lat, lon)
            try:
                levels = preds[0].atmos_vars["t"].shape[2]
                level_indices = list(range(levels))
            except Exception:
                st.error("Error determining available levels:")
                st.error(traceback.format_exc())
                levels = None  # Set to None if levels cannot be determined

            if levels is not None:
                # Slider for level selection
                selected_level = st.slider(
                    'Select Level',
                    min_value=0,
                    max_value=levels - 1,
                    value=min(11, levels - 1),  # Default level index, clamped to the valid range
                    step=1,
                    help="Select the vertical level for plotting.",
                )

                # Loop through predictions and the corresponding ground truths
                for idx in range(len(preds)):
                    pred = preds[idx]
                    pred_time = pred.metadata.time[0]
                    # Display the prediction time
                    st.write(f"### Prediction Time: {pred_time}")

                    # Extract data at the selected level (Kelvin to Celsius)
                    try:
                        pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
                        truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
                    except Exception:
                        st.error("Error extracting data for plotting:")
                        st.error(traceback.format_exc())
                        continue

                    # Extract latitude and longitude (assumed 1D)
                    try:
                        lat = np.array(pred.metadata.lat)
                        lon = np.array(pred.metadata.lon)
                    except Exception:
                        st.error("Error extracting latitude and longitude:")
                        st.error(traceback.format_exc())
                        continue

                    # Create a meshgrid for plotting
                    lon_grid, lat_grid = np.meshgrid(lon, lat)

                    # Matplotlib figure with a Cartopy projection; two panels
                    # (the original allocated a third axis but never used it)
                    fig, axes = plt.subplots(
                        1, 2, figsize=(12, 6),
                        subplot_kw={'projection': ccrs.PlateCarree()},
                    )

                    # Ground truth panel
                    im1 = axes[0].imshow(
                        truth_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree(),
                    )
                    axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
                    axes[0].set_xlabel('Longitude')
                    axes[0].set_ylabel('Latitude')
                    plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)

                    # Prediction panel
                    im2 = axes[1].imshow(
                        pred_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree(),
                    )
                    axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
                    axes[1].set_xlabel('Longitude')
                    axes[1].set_ylabel('Latitude')
                    plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)

                    plt.tight_layout()
                    # Display the plot in Streamlit
                    st.pyplot(fig)
            else:
                st.error("Could not determine the available levels in the data.")
        else:
            st.warning("No output available to display or visualization is not implemented for this model.")

# --- End of Inference Button ---
else:
    with right_col:
        st.header("🖥️ Visualization & Progress")
        st.info("Awaiting inference to display results.")