"""Streamlit front-end for running weather and climate foundation models.

The user picks a model, uploads NetCDF inputs (plus, for Prithvi, optional
climatology / config / weight files), runs inference and inspects the output.
Only the Aurora path runs end to end; the Prithvi dataset and scaler steps are
still stubbed out, and Climax / LSTM are placeholders.
"""

import logging
import os
import random
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import torch
import xarray as xr
import yaml
from aurora import Aurora, Batch, Metadata, rollout

# Local package: the star import supplies PrithviWxC, preproc and the
# (currently commented-out) Prithvi dataset/scaler helpers.
from Prithvi import *  # noqa: F401,F403

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pressure levels (hPa) extracted for Aurora. 1000 hPa is used for the surface
# variables; the remaining levels become Aurora's atmospheric levels. Keeping
# 1000 hPa last means atmospheric level k is also index k of the level subset.
AURORA_LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
AURORA_SURFACE_LEVEL = 1000
# Aurora is fed the "current" time step and the one this many steps earlier.
AURORA_HISTORY_OFFSET = 6
# Index of the "current" time step used to build the Aurora input batch.
AURORA_TIME_INDEX = 6


# ---------------------------------------------------------------------------
# File handling
# ---------------------------------------------------------------------------


def save_uploaded_files(uploaded_files):
    """Persist uploaded files to named temp files and return their paths.

    Each distinct upload (identified by file name and size) is written to disk
    only once per session, so Streamlit reruns do not create duplicate temp
    files. ``st.session_state.temp_file_paths`` is set to the paths of exactly
    the files in ``uploaded_files``, in upload order.
    """
    if "temp_file_cache" not in st.session_state:
        st.session_state.temp_file_cache = {}
    cache = st.session_state.temp_file_cache

    paths = []
    for uploaded_file in uploaded_files:
        # getvalue() is independent of the buffer's read position, unlike read().
        data = uploaded_file.getvalue()
        key = (uploaded_file.name, len(data))
        if key not in cache:
            suffix = os.path.splitext(uploaded_file.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(data)
            cache[key] = temp_file.name
        paths.append(cache[key])

    st.session_state.temp_file_paths = paths
    return paths


def _save_upload_to_temp(uploaded_file, suffix):
    """Write one Streamlit upload to a persistent temp file and return its Path."""
    # NamedTemporaryFile(delete=False) replaces the deprecated, race-prone mktemp().
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(uploaded_file.getbuffer())
    return Path(temp_file.name)


@st.cache_resource
def _open_dataset_cached(file_paths):
    """Open ``file_paths`` as one dataset, load it into memory and close the files.

    Exceptions propagate on purpose: Streamlit only caches successful results,
    so a failed load is retried on the next run instead of being cached.
    """
    with xr.open_mfdataset(list(file_paths), combine="by_coords") as ds:
        return ds.load()


def load_dataset(file_paths):
    """Load NetCDF files into one in-memory dataset.

    Returns the dataset, or None after reporting the error in the UI.
    """
    try:
        return _open_dataset_cached(tuple(file_paths))
    except Exception:
        logger.exception("Failed to load dataset from %s", file_paths)
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None


# ---------------------------------------------------------------------------
# Configuration widgets
# ---------------------------------------------------------------------------


def get_model_configuration(model_name):
    """Render the configuration/upload widgets for ``model_name``.

    Returns:
        For "Prithvi": ``(config, uploaded_surface_files,
        uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path,
        weights_path)``. ``clim_surf_path``/``clim_vert_path`` are None when
        the default climatology files are missing and no replacements were
        uploaded. For every other model: the list of uploaded data files.

    Calls ``st.stop()`` when Prithvi's default config or weights are missing
    and nothing was uploaded in their place.
    """
    if model_name != "Prithvi":
        # Other models only need their input data files.
        st.subheader(f"{model_name} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        return st.file_uploader(
            f"Upload Data Files for {model_name}",
            accept_multiple_files=True,
            key=f"{model_name.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )

    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    # --- Prithvi-specific data uploads ---
    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    # --- Climatology files ---
    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    clim_files_exist = all(
        path.exists()
        for path in (surf_in_scal_path, vert_in_scal_path, surf_out_scal_path, vert_out_scal_path)
    )

    # Stay None unless usable climatology files are found or uploaded, so the
    # return statement never hits an unbound name.
    clim_surf_path = None
    clim_vert_path = None
    if clim_files_exist:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path
    else:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            clim_temp_dir = Path(tempfile.mkdtemp())
            # Keep only the base name so an odd upload name cannot escape the temp dir.
            clim_surf_path = clim_temp_dir / Path(uploaded_clim_surface.name).name
            clim_surf_path.write_bytes(uploaded_clim_surface.getbuffer())
            clim_vert_path = clim_temp_dir / Path(uploaded_clim_vertical.name).name
            clim_vert_path.write_bytes(uploaded_clim_vertical.getbuffer())
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")

    # --- Optional config.yaml ---
    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload_to_temp(uploaded_config, ".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    # --- Optional model weights ---
    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload_to_temp(uploaded_weights, ".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )


# ---------------------------------------------------------------------------
# Aurora helpers
# ---------------------------------------------------------------------------


def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
    """Return steps ``i - AURORA_HISTORY_OFFSET`` and ``i`` of ``x`` with a leading batch axis.

    The first axis of ``x`` is treated as time.
    """
    selected = x[[i - AURORA_HISTORY_OFFSET, i]][None]
    # .copy() guarantees a contiguous, writable buffer for torch.from_numpy.
    return torch.from_numpy(selected.copy())


def _build_aurora_batch(ds, time_index=AURORA_TIME_INDEX):
    """Build an Aurora ``Batch`` from a dataset with T/U/V/SLP/PHIS variables.

    ``T``, ``U`` and ``V`` must have a ``lev`` (pressure level) dimension.
    The dataset also needs ``lat``, ``lon`` and ``time`` coordinates.

    Returns:
        ``(batch, ds_subset)``. ``ds_subset`` is the NaN-filled dataset
        restricted to ``AURORA_LEVELS``.

    Raises:
        ValueError: if the ``lev`` dimension or a required level is missing.
        IndexError: if ``time_index`` leaves no room for the history step.
    """
    # Replace missing values with each variable's global mean.
    ds = ds.fillna(ds.mean())
    if "lev" not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    # NOTE(review): these only change the coordinate values; the data arrays
    # are not flipped/rolled to match. Confirm this is what Aurora expects.
    lat = ds.lat.values * -1
    lon = ds.lon.values + 180

    ds_subset = ds.sel(lev=AURORA_LEVELS, method="nearest")
    # "nearest" always returns something, so verify the exact levels came back.
    missing_levels = set(AURORA_LEVELS) - set(ds_subset.lev.values)
    if missing_levels:
        raise ValueError(
            f"The following desired pressure levels are missing in the dataset: {missing_levels}"
        )

    lev = ds_subset.lev.values  # pressure levels in hPa
    surface_hits = np.flatnonzero(lev == AURORA_SURFACE_LEVEL)
    if surface_hits.size == 0:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
    surface_index = surface_hits[0]

    t_surface = ds_subset.T.isel(lev=surface_index).values
    u_surface = ds_subset.U.isel(lev=surface_index).values
    v_surface = ds_subset.V.isel(lev=surface_index).values
    slp = ds_subset.SLP.values
    # Static field: take the first time step to drop the time dimension.
    phis = ds_subset.PHIS.isel(time=0).values

    atmos_levels = [int(level) for level in lev if level != AURORA_SURFACE_LEVEL]
    t_atm = ds_subset.T.sel(lev=atmos_levels).values
    u_atm = ds_subset.U.sel(lev=atmos_levels).values
    v_atm = ds_subset.V.sel(lev=atmos_levels).values

    # _prepare reads index time_index - AURORA_HISTORY_OFFSET, so that must be >= 0.
    num_times = ds_subset.time.size
    if not AURORA_HISTORY_OFFSET <= time_index < num_times:
        raise IndexError(
            f"Time index {time_index} is out of bounds: need "
            f"{AURORA_HISTORY_OFFSET} <= i < {num_times}."
        )
    current_time = (
        np.datetime64(ds_subset.time.values[time_index])
        .astype("datetime64[s]")
        .astype(datetime)
    )

    batch = Batch(
        surf_vars={
            "2t": _prepare(t_surface, time_index),  # 1000 hPa temperature stands in for 2 m
            "10u": _prepare(u_surface, time_index),  # 1000 hPa eastward wind stands in for 10 m
            "10v": _prepare(v_surface, time_index),  # 1000 hPa northward wind stands in for 10 m
            "msl": _prepare(slp, time_index),  # mean sea-level pressure
        },
        static_vars={
            # Geopotential as a 2D tensor; add 'lsm'/'slt' here if available.
            "z": torch.from_numpy(phis.copy()),
        },
        atmos_vars={
            "t": _prepare(t_atm, time_index),
            "u": _prepare(u_atm, time_index),
            "v": _prepare(v_atm, time_index),
        },
        metadata=Metadata(
            lat=torch.from_numpy(lat.copy()),
            lon=torch.from_numpy(lon.copy()),
            time=(current_time,),
            atmos_levels=tuple(atmos_levels),
        ),
    )
    return batch, ds_subset


@st.cache_resource
def _load_aurora_model():
    """Load the pretrained 0.25 degree Aurora checkpoint once per server process."""
    model = Aurora(use_lora=False)
    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    model.eval()
    return model


def _ground_truth_celsius(ds_subset, level_index, when):
    """Return ``T`` in Celsius at ``level_index`` for dataset time ``when``.

    Returns None when no dataset is available or it has no step at ``when``.
    """
    if ds_subset is None:
        return None
    try:
        field = ds_subset.T.isel(lev=level_index).sel(time=np.datetime64(when))
    except KeyError:
        return None
    return field.values - 273.15


# ---------------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------------


def _render_prithvi_output(out_tensor):
    """Plot one channel of a Prithvi output tensor with Matplotlib and Plotly."""
    st.write(f"**Output tensor shape:** {out_tensor.shape}")
    if out_tensor.ndim < 4:
        st.error(
            "The output tensor does not have the expected number of dimensions "
            "(batch, variables, lat, lon)."
        )
        return

    # Placeholder names; replace with the real variable names.
    variable_names = [f"Variable_{i}" for i in range(out_tensor.shape[1])]

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable you want to visualize.",
        )
        plot_type = st.selectbox(
            "Select Plot Type",
            options=["Contour", "Heatmap"],
            index=0,
            help="Choose the type of plot to display.",
        )
    with col2:
        colormaps = plt.colormaps()
        cmap = st.selectbox(
            "Select Color Map",
            options=colormaps,
            index=colormaps.index("viridis"),
            help="Choose the color map for the plot.",
        )
        if plot_type == "Contour":
            num_levels = st.slider(
                "Number of Contour Levels",
                min_value=5,
                max_value=100,
                value=20,
                step=5,
                help="Set the number of contour levels.",
            )
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    selected_variable = out_tensor[0, variable_index].cpu().numpy()
    # Assumes a regular global grid; TODO use the model's real coordinates.
    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])

    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_type == "Contour":
        lon_grid, lat_grid = np.meshgrid(lon, lat)
        mappable = ax.contourf(lon_grid, lat_grid, selected_variable, levels=num_levels, cmap=cmap)
    else:
        mappable = ax.imshow(
            selected_variable,
            extent=[-180, 180, -90, 90],
            cmap=cmap,
            origin="lower",
            aspect="auto",
        )
    cbar = plt.colorbar(mappable, ax=ax)
    cbar.set_label(f"{selected_variable_name}", fontsize=12)
    ax.set_xlabel("Longitude", fontsize=12)
    ax.set_ylabel("Latitude", fontsize=12)
    ax.set_title(f"{selected_variable_name}", fontsize=14)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit reruns would otherwise accumulate them

    st.markdown("#### Interactive Plot")

    def _plotly_figure(colorscale):
        if plot_type == "Contour":
            trace = go.Contour(
                z=selected_variable,
                x=lon,
                y=lat,
                colorscale=colorscale,
                # ncontours is a trace property; inside `contours` Plotly rejects it.
                ncontours=num_levels,
                contours=dict(
                    coloring="fill",
                    showlabels=True,
                    labelfont=dict(size=12, color="white"),
                ),
            )
        else:
            trace = go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=colorscale)
        return go.Figure(data=trace)

    try:
        fig_plotly = _plotly_figure(cmap)
    except ValueError:
        # Many Matplotlib colormap names have no Plotly equivalent.
        st.info(f"Plotly has no '{cmap}' colorscale; using Viridis instead.")
        fig_plotly = _plotly_figure("Viridis")
    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        autosize=False,
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)


def _render_aurora_output(preds, ds_subset):
    """Plot ground-truth vs. predicted temperature (Celsius) for each rollout step."""
    try:
        # atmos_vars["t"] is indexed [batch][time][level] below, so axis 2 holds the levels.
        levels = preds[0].atmos_vars["t"].shape[2]
    except Exception:
        st.error("Error determining available levels:")
        st.error(traceback.format_exc())
        st.error("Could not determine the available levels in the data.")
        return

    if levels > 1:
        selected_level = st.slider(
            "Select Level",
            min_value=0,
            max_value=levels - 1,
            value=min(11, levels - 1),  # default index, clamped to the available range
            step=1,
            help="Select the vertical level for plotting.",
        )
    else:
        selected_level = 0

    for pred in preds:
        pred_time = pred.metadata.time[0]
        st.write(f"### Prediction Time: {pred_time}")

        try:
            pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
            # Match truth to the prediction's valid time, not its position in the rollout.
            truth_data = _ground_truth_celsius(ds_subset, selected_level, pred_time)
        except Exception:
            st.error("Error extracting data for plotting:")
            st.error(traceback.format_exc())
            continue

        try:
            lat = np.array(pred.metadata.lat)
            lon = np.array(pred.metadata.lon)
        except Exception:
            st.error("Error extracting latitude and longitude:")
            st.error(traceback.format_exc())
            continue

        extent = [lon.min(), lon.max(), lat.min(), lat.max()]
        fig, axes = plt.subplots(
            1, 2, figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()}
        )
        panels = (
            (axes[0], truth_data, f"Ground Truth at Level {selected_level} - {pred_time}"),
            (axes[1], pred_data, f"Prediction at Level {selected_level} - {pred_time}"),
        )
        for ax, data, title in panels:
            ax.set_title(title)
            ax.set_xlabel("Longitude")
            ax.set_ylabel("Latitude")
            if data is None:
                ax.text(
                    0.5, 0.5, "No ground truth for this time",
                    transform=ax.transAxes, ha="center", va="center",
                )
                continue
            image = ax.imshow(
                data,
                extent=extent,
                origin="lower",
                cmap="coolwarm",
                transform=ccrs.PlateCarree(),
            )
            plt.colorbar(image, ax=ax, orientation="horizontal", pad=0.05)
        plt.tight_layout()
        st.pyplot(fig)
        plt.close(fig)


# ---------------------------------------------------------------------------
# Page layout
# ---------------------------------------------------------------------------

st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

header_col1, header_col2 = st.columns([4, 1])
with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")
with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "Model",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        label_visibility="collapsed",  # the markdown heading above acts as the label
        help="Select the model you want to use for processing the data.",
    )

st.write("---")

left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")
    if selected_model == "Prithvi":
        (
            config,
            uploaded_surface_files,
            uploaded_vertical_files,
            clim_surf_path,
            clim_vert_path,
            config_path,
            weights_path,
        ) = get_model_configuration(selected_model)
    else:
        uploaded_files = get_model_configuration(selected_model)
    st.write("---")
    run_inference = st.button("🚀 Run Inference")

with right_col:
    if run_inference:
        st.header("📈 Inference Progress & Visualization")
        # Drop results from a previous run so a failure below cannot leave
        # stale output on screen.
        for stale_key in ("out", "out_model"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

        # --- Device ---
        try:
            torch.jit.enable_onednn_fusion(True)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                # NOTE(review): PyTorch's reproducibility notes recommend
                # benchmark=False together with deterministic=True.
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = True
            else:
                device = torch.device("cpu")
                st.write("Using device: **CPU**")
        except Exception:
            st.error("Error initializing device:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Random seeds ---
        try:
            random.seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            torch.manual_seed(42)
            np.random.seed(42)
        except Exception:
            st.error("Error setting random seeds:")
            st.error(traceback.format_exc())
            st.stop()

        # TODO(prithvi): choose variables per dataset type. For MERRA2:
        #   surface_vars = ["EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM",
        #                   "LWTUP", "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M",
        #                   "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M"]
        #   static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
        #   vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
        #   levels = [34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0,
        #             63.0, 68.0, 71.0, 72.0]
        # GEOS5 used the same names with a "GEOS5_" prefix and
        #   levels = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0].
        padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
        residual = "climate"
        masking_mode = "local"
        decoder_shifting = True
        masking_ratio = 0.99
        positional_encoding = "fourier"

        # --- Dataset ---
        try:
            with st.spinner("Initializing dataset..."):
                if selected_model == "Prithvi":
                    # TODO(prithvi): build the dataset once variables are chosen, e.g.
                    #   dataset = Merra2Dataset(
                    #       time_range=time_range, lead_times=lead_times,
                    #       input_times=input_times, data_path_surface=surf_dir,
                    #       data_path_vertical=vert_dir,
                    #       climatology_path_surface=clim_surf_path,
                    #       climatology_path_vertical=clim_vert_path,
                    #       surface_vars=surface_vars,
                    #       static_surface_vars=static_surface_vars,
                    #       vertical_vars=vertical_vars, levels=levels,
                    #       positional_encoding=positional_encoding,
                    #   )
                    #   assert len(dataset) > 0, "There doesn't seem to be any valid data."
                    # Until then the later Prithvi steps would fail on undefined
                    # names (dataset, in_mu, ...), so stop with a clear message.
                    st.warning("Dataset initialization for Prithvi is not implemented yet.")
                    st.stop()
                elif selected_model == "Aurora":
                    if not uploaded_files:
                        st.warning("Please upload at least one data file before running Aurora.")
                        st.stop()
                    ds = load_dataset(save_uploaded_files(uploaded_files))
                    if ds is None:
                        st.stop()  # load_dataset has already reported the error
                    st.success("Files successfully loaded!")
                    batch, ds_subset = _build_aurora_batch(ds)
                    st.session_state.ds_subset = ds_subset
                    st.session_state["batch"] = batch
                else:
                    st.warning("Dataset initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Dataset initialized successfully.")
        except Exception:
            st.error("Error initializing dataset:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Scalers ---
        try:
            with st.spinner("Loading scalers..."):
                if selected_model == "Prithvi":
                    # TODO(prithvi): load the real scalers, with the *_scal_path values
                    # derived from clim_surf_path / clim_vert_path and the
                    # anomaly_variance_{surface,vertical}.nc files next to them:
                    #   in_mu, in_sig = input_scalers(surface_vars, vertical_vars, levels,
                    #                                 surf_in_scal_path, vert_in_scal_path)
                    #   output_sig = output_scalers(surface_vars, vertical_vars, levels,
                    #                               surf_out_scal_path, vert_out_scal_path)
                    #   static_mu, static_sig = static_input_scalers(surf_in_scal_path,
                    #                                                static_surface_vars)
                    pass
                else:
                    # Placeholder: other models do not use scalers yet.
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
            st.success("Scalers loaded successfully.")
        except Exception:
            st.error("Error loading scalers:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Configuration ---
        try:
            with st.spinner("Loading configuration..."):
                if selected_model == "Prithvi":
                    with open(config_path, "r") as f:
                        config = yaml.safe_load(f)
                    required_params = [
                        "in_channels", "input_size_time", "in_channels_static",
                        "input_scalers_epsilon", "static_input_scalers_epsilon",
                        "n_lats_px", "n_lons_px", "patch_size_px", "mask_unit_size_px",
                        "embed_dim", "n_blocks_encoder", "n_blocks_decoder",
                        "mlp_multiplier", "n_heads", "dropout", "drop_path",
                        "parameter_dropout",
                    ]
                    params = config.get("params", {})
                    missing_params = [param for param in required_params if param not in params]
                    if missing_params:
                        st.error(f"Missing configuration parameters: {missing_params}")
                        st.stop()
                else:
                    # Placeholder: other models have no configuration file yet.
                    config = {}
            st.success("Configuration loaded successfully.")
        except Exception:
            st.error("Error loading configuration:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Model ---
        try:
            with st.spinner("Initializing model..."):
                if selected_model == "Prithvi":
                    params = config["params"]
                    model = PrithviWxC(
                        in_channels=params["in_channels"],
                        input_size_time=params["input_size_time"],
                        in_channels_static=params["in_channels_static"],
                        input_scalers_mu=in_mu,
                        input_scalers_sigma=in_sig,
                        input_scalers_epsilon=params["input_scalers_epsilon"],
                        static_input_scalers_mu=static_mu,
                        static_input_scalers_sigma=static_sig,
                        static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
                        output_scalers=output_sig**0.5,
                        n_lats_px=params["n_lats_px"],
                        n_lons_px=params["n_lons_px"],
                        patch_size_px=params["patch_size_px"],
                        mask_unit_size_px=params["mask_unit_size_px"],
                        mask_ratio_inputs=masking_ratio,
                        embed_dim=params["embed_dim"],
                        n_blocks_encoder=params["n_blocks_encoder"],
                        n_blocks_decoder=params["n_blocks_decoder"],
                        mlp_multiplier=params["mlp_multiplier"],
                        n_heads=params["n_heads"],
                        dropout=params["dropout"],
                        drop_path=params["drop_path"],
                        parameter_dropout=params["parameter_dropout"],
                        residual=residual,
                        masking_mode=masking_mode,
                        decoder_shifting=decoder_shifting,
                        positional_encoding=positional_encoding,
                        checkpoint_encoder=[],
                        checkpoint_decoder=[],
                    )
                elif selected_model == "Aurora":
                    pass  # Aurora is loaded (and cached) in the inference step.
                else:
                    st.warning("Model initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Model initialized successfully.")
        except Exception:
            st.error("Error initializing model:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Weights ---
        try:
            with st.spinner("Loading model weights..."):
                if selected_model == "Prithvi":
                    # SECURITY: torch.load unpickles arbitrary objects and weights_path
                    # may be a user upload; consider weights_only=True if the
                    # checkpoint format allows it.
                    state_dict = torch.load(weights_path, map_location=device)
                    if "model_state" in state_dict:
                        state_dict = state_dict["model_state"]
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                # Other models: weights are loaded elsewhere or not needed yet.
            st.success("Model weights loaded successfully.")
        except Exception:
            st.error("Error loading model weights:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Data batch ---
        try:
            with st.spinner("Preparing data batch..."):
                if selected_model == "Prithvi":
                    data = next(iter(dataset))
                    batch = preproc([data], padding)
                    for key, value in batch.items():
                        if isinstance(value, torch.Tensor):
                            batch[key] = value.to(device)
                elif selected_model == "Aurora":
                    batch = batch.regrid(res=0.25)
                else:
                    batch = None  # placeholder for other models
            st.success("Data batch prepared successfully.")
        except Exception:
            st.error("Error preparing data batch:")
            st.error(traceback.format_exc())
            st.stop()

        # --- Inference ---
        try:
            with st.spinner("Running model inference..."):
                if selected_model == "Prithvi":
                    model.eval()
                    with torch.no_grad():
                        out = model(batch)
                elif selected_model == "Aurora":
                    model = _load_aurora_model()
                    # To run on a GPU, move the model first: model = model.to("cuda")
                    with torch.inference_mode():
                        out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
                    st.session_state.model = model
                else:
                    out = torch.randn(1, 10, 180, 360)  # dummy output for other models
            st.success("Model inference completed successfully.")
            # Persist results so widget changes (which rerun the script with the
            # button reset to False) can still re-render the plots.
            st.session_state["out"] = out
            st.session_state["out_model"] = selected_model
        except Exception:
            st.error("Error during model inference:")
            st.error(traceback.format_exc())
            st.stop()

    # --- Visualization (runs on every rerun while results exist) ---
    results_ready = (
        "out" in st.session_state
        and st.session_state.get("out_model") == selected_model
    )
    if not run_inference and not results_ready:
        st.header("🖥️ Visualization & Progress")
        st.info("Awaiting inference to display results.")
    else:
        if not run_inference:
            st.header("🖥️ Visualization & Progress")
        st.markdown("## 📊 Visualization Settings")
        if results_ready and selected_model == "Prithvi":
            _render_prithvi_output(st.session_state["out"])
        elif (
            results_ready
            and selected_model == "Aurora"
            and st.session_state["out"] is not None
        ):
            _render_aurora_output(st.session_state["out"], st.session_state.get("ds_subset"))
        else:
            st.warning(
                "No output available to display or visualization is not implemented for this model."
            )