import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import cartopy.crs as ccrs


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached from autograd and moved to CPU."""
    # .detach() is needed when the tensor still tracks gradients (otherwise
    # .numpy() raises); .cpu() handles tensors that live on an accelerator.
    return tensor.detach().cpu().numpy()


def _mpl_cmap_to_plotly(cmap_name, n_samples=32):
    """Convert a Matplotlib colormap name into an explicit Plotly colorscale.

    Plotly only recognises its own named colorscales, so forwarding an
    arbitrary Matplotlib name (e.g. ``"coolwarm"``) to ``colorscale=`` raises.
    Sampling the Matplotlib colormap keeps the static and interactive plots
    visually identical for every choice offered in the UI.

    Returns a list of ``[position, "rgb(r,g,b)"]`` pairs with positions in [0, 1].
    """
    cmap = plt.get_cmap(cmap_name)
    colorscale = []
    for i in range(n_samples):
        position = i / (n_samples - 1)
        r, g, b, _alpha = cmap(position)
        colorscale.append(
            [position, f"rgb({round(r * 255)},{round(g * 255)},{round(b * 255)})"]
        )
    return colorscale


def plot_prithvi_output(out_tensor):
    """Render Streamlit controls plus a static and an interactive map of a Prithvi output.

    Parameters
    ----------
    out_tensor:
        Tensor-like model output. Variables are enumerated along axis 1 and the
        plotted field is ``out_tensor[0, variable_index]``, which must be 2-D.
        Presumably (batch, variable, lat, lon) -- TODO confirm against caller.
        ``None`` shows a warning and returns without plotting.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    # Example visualization UI for Prithvi:
    st.markdown("## 📊 Visualization Settings")

    # Variables are read from axis 1 and the map from out_tensor[0, i], so at
    # least four dimensions are required.
    if out_tensor.ndim < 4:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    if num_variables == 0:
        st.error("The output tensor contains no variables to plot.")
        return
    variable_names = [f"Variable_{i}" for i in range(num_variables)]
    colormap_names = plt.colormaps()

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize.",
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)
    with col2:
        cmap = st.selectbox(
            "Select Color Map",
            options=colormap_names,
            index=colormap_names.index("viridis"),
        )
        if plot_type == "Contour":
            num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5)
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    selected_variable = _to_numpy(out_tensor[0, variable_index])
    # With ndim > 4 the slice is still multi-dimensional and contourf/imshow
    # would fail with an obscure error, so report it clearly instead.
    if selected_variable.ndim != 2:
        st.error(
            f"Expected a 2-D (lat, lon) field for {selected_variable_name}, "
            f"got shape {selected_variable.shape}."
        )
        return

    # No coordinates are supplied, so assume an evenly spaced global grid.
    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])

    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_type == "Contour":
        lon_grid, lat_grid = np.meshgrid(lon, lat)
        mappable = ax.contourf(
            lon_grid, lat_grid, selected_variable, levels=num_levels, cmap=cmap
        )
    else:
        mappable = ax.imshow(
            selected_variable,
            extent=[-180, 180, -90, 90],
            cmap=cmap,
            origin="lower",
            aspect="auto",
        )
    cbar = fig.colorbar(mappable, ax=ax)
    cbar.set_label(selected_variable_name, fontsize=12)
    ax.set_xlabel("Longitude", fontsize=12)
    ax.set_ylabel("Latitude", fontsize=12)
    ax.set_title(selected_variable_name, fontsize=14)
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure so
    # pyplot does not accumulate open figures.
    plt.close(fig)

    # Plotly interactive plot
    st.markdown("#### Interactive Plot")
    plotly_colorscale = _mpl_cmap_to_plotly(cmap)
    if plot_type == "Contour":
        trace = go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=plotly_colorscale,
            # ncontours belongs to the Contour trace itself; inside `contours`
            # Plotly rejects it as an invalid property.
            ncontours=num_levels,
            contours=dict(
                coloring="fill",
                showlabels=True,
                labelfont=dict(size=12, color="white"),
            ),
        )
    else:
        trace = go.Heatmap(
            z=selected_variable, x=lon, y=lat, colorscale=plotly_colorscale
        )
    fig_plotly = go.Figure(data=trace)
    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)


def _draw_map(ax, data, extent, cmap, title, **imshow_kwargs):
    """Draw *data* on a PlateCarree axis with a title and a horizontal colorbar."""
    image = ax.imshow(
        data,
        extent=extent,
        origin="lower",
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        **imshow_kwargs,
    )
    ax.set_title(title)
    ax.figure.colorbar(image, ax=ax, orientation="horizontal", pad=0.05)


def plot_aurora_output(preds, ds_subset, default_level=11):
    """Plot ground truth, prediction and their difference for each Aurora step.

    Parameters
    ----------
    preds:
        Sequence of prediction objects exposing ``atmos_vars["t"]`` (indexed as
        ``[0][0][level]``) and ``metadata.time/lat/lon``. The level count is read
        from axis 2 of ``atmos_vars["t"]``.
    ds_subset:
        Dataset-like object whose ``T.isel(lev=level)[idx].values`` is the
        ground-truth field for prediction ``idx``.
    default_level:
        Preferred initial slider position; clamped to the available levels.

    Temperatures are converted from Kelvin to Celsius (``- 273.15``) before
    plotting. Steps whose data cannot be extracted are reported and skipped.
    """
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except (AttributeError, IndexError, KeyError, TypeError):
        st.error("Could not determine available levels in the data.")
        return

    if levels < 1:
        st.error("Could not determine available levels in the data.")
        return
    if levels == 1:
        # A slider needs min < max; with a single level there is nothing to pick.
        selected_level = 0
    else:
        # st.slider raises if the default lies outside [min, max], so clamp it.
        initial_level = min(max(default_level, 0), levels - 1)
        selected_level = st.slider("Select Level", 0, levels - 1, initial_level, 1)

    for idx, pred in enumerate(preds):
        pred_time = pred.metadata.time[0]
        try:
            pred_data = _to_numpy(pred.atmos_vars["t"][0][0][selected_level]) - 273.15
            truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
        except Exception as e:  # best-effort per step: report and move on
            st.error(f"Error extracting data for plotting: {e}")
            continue

        lat = np.asarray(pred.metadata.lat)
        lon = np.asarray(pred.metadata.lon)
        # NOTE(review): origin="lower" assumes row 0 is the southernmost
        # latitude for both arrays; confirm the latitude ordering of each source.
        extent = [lon.min(), lon.max(), lat.min(), lat.max()]

        fig, axes = plt.subplots(
            1, 3, figsize=(18, 6), subplot_kw={"projection": ccrs.PlateCarree()}
        )
        _draw_map(
            axes[0],
            truth_data,
            extent,
            "coolwarm",
            f"Ground Truth at Level {selected_level} - {pred_time}",
        )
        _draw_map(
            axes[1],
            pred_data,
            extent,
            "coolwarm",
            f"Prediction at Level {selected_level} - {pred_time}",
        )

        if np.shape(pred_data) == np.shape(truth_data):
            diff = np.asarray(pred_data) - np.asarray(truth_data)
            finite = np.abs(diff[np.isfinite(diff)])
            bound = float(finite.max()) if finite.size else 0.0
            # Symmetric limits centre the diverging colormap on zero error.
            limits = {"vmin": -bound, "vmax": bound} if bound > 0 else {}
            _draw_map(
                axes[2],
                diff,
                extent,
                "RdBu_r",
                f"Prediction - Ground Truth at Level {selected_level} - {pred_time}",
                **limits,
            )
        else:
            # Grids differ, so an element-wise difference is meaningless.
            axes[2].set_visible(False)

        plt.tight_layout()
        st.pyplot(fig)
        # Close each per-step figure to avoid leaking figures across reruns.
        plt.close(fig)