Spaces:
Sleeping
Sleeping
import streamlit as st | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import plotly.graph_objects as go | |
import cartopy.crs as ccrs | |
def _mpl_cmap_to_plotly(cmap_name, n_stops=11):
    """Convert a Matplotlib colormap name into an explicit Plotly colorscale.

    Plotly only understands its own named colorscales, so passing an arbitrary
    Matplotlib name (for example ``"coolwarm"``) straight to ``colorscale``
    raises a ``ValueError``. Sampling the Matplotlib colormap at ``n_stops``
    evenly spaced positions yields a ``[[position, "rgb(r,g,b)"], ...]`` list
    that Plotly always accepts and that approximates the Matplotlib rendering.
    """
    colormap = plt.get_cmap(cmap_name)
    colorscale = []
    for position in np.linspace(0.0, 1.0, n_stops):
        red, green, blue, _alpha = colormap(float(position))
        rgb = ",".join(str(int(round(channel * 255))) for channel in (red, green, blue))
        colorscale.append([float(position), f"rgb({rgb})"])
    return colorscale


def plot_prithvi_output(out_tensor):
    """Render Streamlit controls plus a static and an interactive plot of a Prithvi output.

    Parameters
    ----------
    out_tensor :
        Model output. The field to plot is taken as ``out_tensor[0, variable_index]``
        and must be 2-D, so a 4-D tensor is required. It must support ``ndim``,
        ``shape`` and ``.detach().cpu().numpy()`` (presumably a torch tensor laid
        out as (batch, variables, lat, lon) — TODO confirm). When ``None``, a
        warning is shown and nothing is plotted.

    Notes
    -----
    The tensor carries no coordinates, so a regular global grid is assumed.
    Latitudes span -90..90 along axis 2 and longitudes span -180..180 along
    axis 3. Variables are labelled generically as ``Variable_<i>``.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    st.markdown("## π Visualization Settings")

    # The indexing below selects [0, variable] and expects a 2-D (lat, lon)
    # field, so anything other than a 4-D tensor cannot be plotted.
    if out_tensor.ndim != 4:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    variable_names = [f"Variable_{i}" for i in range(num_variables)]
    colormap_names = plt.colormaps()

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize."
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)
    with col2:
        cmap = st.selectbox(
            "Select Color Map",
            options=colormap_names,
            index=colormap_names.index("viridis"),
        )
        if plot_type == "Contour":
            num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5)
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    # detach() so tensors that still track gradients can be converted to NumPy.
    selected_variable = out_tensor[0, variable_index].detach().cpu().numpy()

    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])
    X, Y = np.meshgrid(lon, lat)

    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_type == "Contour":
        mappable = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
    else:
        mappable = ax.imshow(
            selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto'
        )
    cbar = fig.colorbar(mappable, ax=ax)
    cbar.set_label(f'{selected_variable_name}', fontsize=12)
    ax.set_xlabel("Longitude", fontsize=12)
    ax.set_ylabel("Latitude", fontsize=12)
    ax.set_title(selected_variable_name, fontsize=14)
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction; close the figure
    # so pyplot does not keep accumulating open figures across reruns.
    plt.close(fig)

    # Plotly interactive plot
    st.markdown("#### Interactive Plot")
    colorscale = _mpl_cmap_to_plotly(cmap)
    if plot_type == "Contour":
        fig_plotly = go.Figure(data=go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=colorscale,
            # ncontours is a trace-level property of go.Contour; it is not a
            # valid key of the `contours` settings dict.
            ncontours=num_levels,
            contours=dict(coloring='fill', showlabels=True, labelfont=dict(size=12, color='white')),
        ))
    else:
        fig_plotly = go.Figure(data=go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=colorscale))
    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)
def _image_origin(lat_values):
    """Return the ``imshow`` origin that keeps north at the top of the map.

    With ``extent=[.., lat.min(), lat.max()]``, ``origin='lower'`` draws row 0
    at the bottom, which is only correct when row 0 is the southernmost
    latitude. ``origin='upper'`` draws row 0 at the top, which is correct when
    latitudes run north-to-south. This assumes the rows of the field follow
    the order of ``lat_values``. Falls back to ``'lower'`` (the original
    behaviour) when the ordering is unknown.
    """
    if lat_values is None:
        return "lower"
    lat_values = np.asarray(lat_values).ravel()
    if lat_values.size >= 2 and lat_values[0] > lat_values[-1]:
        return "upper"
    return "lower"


def _truth_latitudes(truth_field):
    """Return the ``lat`` coordinate values of a ground-truth field, or ``None``.

    NOTE(review): assumes an xarray-like object whose latitude coordinate is
    named ``"lat"``; anything else falls back to ``None``.
    """
    coords = getattr(truth_field, "coords", None)
    if coords is not None and "lat" in coords:
        return np.asarray(coords["lat"].values)
    return None


def plot_aurora_output(preds, ds_subset):
    """Plot ground-truth vs. predicted temperature maps for each Aurora prediction.

    Parameters
    ----------
    preds :
        Sequence of prediction objects. Each must expose ``atmos_vars["t"]``
        (indexed as ``[0][0][level]`` to obtain a 2-D field) and
        ``metadata.time`` / ``metadata.lat`` / ``metadata.lon``.
    ds_subset :
        Dataset whose ``T`` variable supports ``.isel(lev=...)`` and is indexed
        by prediction position ``idx`` (presumably an xarray Dataset whose first
        dimension is time — TODO confirm).

    A level slider selects the vertical level shown. Errors are reported
    through Streamlit rather than raised. Fields are shifted by -273.15
    (presumably Kelvin to °C).
    """
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except (AttributeError, IndexError, KeyError, TypeError):
        st.error("Could not determine available levels in the data.")
        return
    if levels < 1:
        st.error("Could not determine available levels in the data.")
        return

    if levels == 1:
        # A single level leaves nothing to choose; avoid a slider whose
        # min_value equals its max_value.
        selected_level = 0
    else:
        # Clamp the default (11) so it is always a valid level index even
        # when the data has fewer than 12 levels.
        selected_level = st.slider('Select Level', 0, levels - 1, min(11, levels - 1), 1)

    for idx, pred in enumerate(preds):
        pred_time = pred.metadata.time[0]
        try:
            # detach()/cpu() so tensors on an accelerator or tracking gradients convert cleanly.
            pred_data = pred.atmos_vars["t"][0][0][selected_level].detach().cpu().numpy() - 273.15
            truth_field = ds_subset.T.isel(lev=selected_level)[idx]
            truth_data = truth_field.values - 273.15
        except Exception as e:
            st.error("Error extracting data for plotting:")
            st.error(e)
            continue

        lat = np.array(pred.metadata.lat)
        lon = np.array(pred.metadata.lon)
        extent = [lon.min(), lon.max(), lat.min(), lat.max()]

        # Only two maps are drawn, so allocate exactly two axes (a 1x3 grid
        # previously left an empty third map panel).
        fig, axes = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})
        panels = (
            (truth_data, _image_origin(_truth_latitudes(truth_field)),
             f"Ground Truth at Level {selected_level} - {pred_time}"),
            (pred_data, _image_origin(lat),
             f"Prediction at Level {selected_level} - {pred_time}"),
        )
        for ax, (data, origin, title) in zip(axes, panels):
            # The origin follows each field's latitude ordering so that grids
            # stored north-to-south are not drawn upside down.
            image = ax.imshow(
                data, extent=extent,
                origin=origin, cmap='coolwarm', transform=ccrs.PlateCarree()
            )
            ax.set_title(title)
            fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.05)
        fig.tight_layout()
        st.pyplot(fig)
        # Release the figure; Streamlit reruns would otherwise accumulate open figures.
        plt.close(fig)