|
import streamlit as st |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
import cartopy.crs as ccrs |
|
|
|
def _plotly_colorscale(colormap, n_samples=32):
    """Sample a Matplotlib colormap into a Plotly ``colorscale`` list.

    Plotly only understands its own named colorscales, so passing most
    Matplotlib colormap names straight through (e.g. ``"coolwarm"``,
    ``"gist_earth"``) makes Plotly raise ``ValueError``. Sampling the
    colormap lets the interactive plot honour every option the UI offers.

    Args:
        colormap: Callable mapping a float in [0, 1] to an RGBA tuple with
            components in [0, 1]; a Matplotlib ``Colormap`` qualifies.
        n_samples: Number of evenly spaced stops to sample (clamped to >= 2).

    Returns:
        A list of ``[position, "rgb(r, g, b)"]`` pairs usable as a Plotly
        ``colorscale``.
    """
    n_samples = max(2, int(n_samples))
    scale = []
    for i in range(n_samples):
        position = i / (n_samples - 1)
        red, green, blue = colormap(position)[:3]
        scale.append(
            [position, f"rgb({round(red * 255)}, {round(green * 255)}, {round(blue * 255)})"]
        )
    return scale


def plot_prithvi_output(out_tensor):
    """Render Streamlit controls and plots for one channel of a Prithvi output.

    Lets the user pick a channel, a plot type (filled contour or heatmap), a
    colormap and, for contours, the number of levels; then draws a static
    Matplotlib figure and an interactive Plotly figure of ``out_tensor[0, v]``
    on an evenly spaced grid spanning -90..90 latitude and -180..180 longitude
    (the grid is synthesized here, not read from the data).

    Args:
        out_tensor: Tensor with at least 4 dimensions whose ``[0, v]`` slice is
            a 2-D field; presumably a torch tensor shaped
            (batch, variable, lat, lon) — TODO confirm against the caller.
            ``None`` shows a warning and draws nothing.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    st.markdown("## π Visualization Settings")

    if out_tensor.ndim < 4:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    variable_names = [f"Variable_{i}" for i in range(num_variables)]

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize.",
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)

    with col2:
        colormap_names = plt.colormaps()
        cmap = st.selectbox(
            "Select Color Map",
            options=colormap_names,
            index=colormap_names.index("viridis"),
        )
        if plot_type == "Contour":
            num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5)
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    # detach() so outputs that still track gradients can be converted to NumPy.
    selected_variable = out_tensor[0, variable_index].detach().cpu().numpy()

    # Tensors with more than 4 dims leave extra axes here, which both plotting
    # paths below cannot draw; report it instead of crashing mid-render.
    if selected_variable.ndim != 2:
        st.error("The output tensor does not have the expected dimensions.")
        return

    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])
    X, Y = np.meshgrid(lon, lat)

    st.markdown(f"### Plot of {selected_variable_name}")

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == "Contour":
            mappable = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
        else:
            mappable = ax.imshow(
                selected_variable,
                extent=[-180, 180, -90, 90],
                cmap=cmap,
                origin='lower',
                aspect='auto',
            )

        cbar = fig.colorbar(mappable, ax=ax)
        cbar.set_label(f'{selected_variable_name}', fontsize=12)
        ax.set_xlabel("Longitude", fontsize=12)
        ax.set_ylabel("Latitude", fontsize=12)
        ax.set_title(selected_variable_name, fontsize=14)
        st.pyplot(fig)
    finally:
        # Streamlit reruns this function on every widget change; pyplot keeps
        # figures registered until closed, so release each one explicitly.
        plt.close(fig)

    st.markdown("#### Interactive Plot")

    colorscale = _plotly_colorscale(plt.get_cmap(cmap))
    if plot_type == "Contour":
        trace = go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=colorscale,
            # ncontours is a trace-level attribute; Plotly rejects it inside `contours`.
            ncontours=num_levels,
            contours=dict(
                coloring='fill',
                showlabels=True,
                labelfont=dict(size=12, color='white'),
            ),
        )
    else:
        trace = go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=colorscale)

    fig_plotly = go.Figure(data=trace)
    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)
|
|
|
|
|
def _rows_south_to_north(data, lat):
    """Return *data* with its rows ordered from south to north.

    Args:
        data: 2-D array whose rows correspond, in order, to the entries of *lat*.
        lat: Latitudes of the rows of *data*, or ``None`` when unknown (then
            *data* is returned unchanged).

    Returns:
        *data* itself, or a row-reversed view when *lat* is descending.
    """
    if lat is None:
        return data
    lat = np.asarray(lat).ravel()
    # Descending latitudes mean row 0 is the northernmost; reverse the rows so
    # that imshow(origin='lower') draws north at the top.
    if lat.size > 1 and lat[0] > lat[-1]:
        return data[::-1]
    return data


def _latitude_coord(data_array):
    """Return latitude values from an xarray-like object, or ``None``.

    Looks for a coordinate named ``"lat"`` or ``"latitude"`` (common
    MERRA-2 / ERA5 naming — an assumption, not checked against the data
    source). Returns ``None`` when neither is present.
    """
    coords = getattr(data_array, "coords", None)
    if coords is None:
        return None
    for name in ("lat", "latitude"):
        if name in coords:
            return np.asarray(coords[name].values)
    return None


def plot_aurora_output(preds, ds_subset):
    """Plot Aurora temperature predictions next to ground truth in Streamlit.

    For each prediction, draws two maps — ground truth and prediction — of
    temperature at a user-selected level, shifted by -273.15 (presumably
    Kelvin to °C; TODO confirm both sources are in Kelvin).

    Args:
        preds: Sequence of prediction objects. Each must expose
            ``atmos_vars["t"]`` (indexed ``[0][0][level]`` to a 2-D field, so
            its axis 2 is the level axis), ``metadata.time``,
            ``metadata.lat`` and ``metadata.lon``.
        ds_subset: xarray-like dataset whose ``T`` variable has a ``lev``
            dimension; ``T.isel(lev=level)[idx]`` is the field compared with
            ``preds[idx]``.
    """
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except (AttributeError, KeyError, IndexError, TypeError):
        st.error("Could not determine available levels in the data.")
        return

    if levels <= 0:
        st.error("Could not determine available levels in the data.")
        return
    if levels == 1:
        # st.slider requires min_value < max_value, so there is nothing to choose.
        selected_level = 0
    else:
        # The default must lie inside [0, levels - 1]; 11 only exists with >= 12 levels.
        selected_level = st.slider('Select Level', 0, levels - 1, min(11, levels - 1), 1)

    for idx, pred in enumerate(preds):
        pred_time = pred.metadata.time[0]

        try:
            # detach().cpu() so GPU or gradient-tracking tensors convert cleanly.
            pred_data = (
                pred.atmos_vars["t"][0][0][selected_level].detach().cpu().numpy() - 273.15
            )
            truth_field = ds_subset.T.isel(lev=selected_level)[idx]
            truth_data = truth_field.values - 273.15
        except Exception as e:
            # Best effort: report this prediction's failure and keep plotting the rest.
            st.error("Error extracting data for plotting:")
            st.error(e)
            continue

        lat = np.array(pred.metadata.lat)
        lon = np.array(pred.metadata.lon)
        # Both panels are placed using the prediction grid's bounds.
        extent = [lon.min(), lon.max(), lat.min(), lat.max()]

        panels = (
            (
                _rows_south_to_north(truth_data, _latitude_coord(truth_field)),
                f"Ground Truth at Level {selected_level} - {pred_time}",
            ),
            (
                _rows_south_to_north(pred_data, lat),
                f"Prediction at Level {selected_level} - {pred_time}",
            ),
        )

        fig, axes = plt.subplots(
            1, len(panels), figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()}
        )
        try:
            for ax, (image, title) in zip(axes, panels):
                im = ax.imshow(
                    image, extent=extent,
                    origin='lower', cmap='coolwarm', transform=ccrs.PlateCarree()
                )
                ax.set_title(title)
                fig.colorbar(im, ax=ax, orientation='horizontal', pad=0.05)

            fig.tight_layout()
            st.pyplot(fig)
        finally:
            # pyplot keeps figures registered until closed; release each one so
            # repeated reruns and long prediction lists do not accumulate figures.
            plt.close(fig)
|
|