# Author: qq1990 — initial commit 100edb4 ("init").
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import cartopy.crs as ccrs
def _mpl_cmap_to_plotly_colorscale(cmap_name, n_stops=32):
    """Sample a Matplotlib colormap into an explicit Plotly colorscale.

    Plotly only understands its own named colorscales, so handing it an
    arbitrary Matplotlib colormap name (e.g. "coolwarm") raises ValueError.
    An explicit ``[[position, "rgb(r,g,b)"], ...]`` list works for any map.

    Args:
        cmap_name: Name of a registered Matplotlib colormap.
        n_stops: Number of evenly spaced samples; must be >= 2.

    Returns:
        List of ``[position, color]`` pairs with positions spanning [0, 1].
    """
    colormap = plt.get_cmap(cmap_name)
    scale = []
    for i in range(n_stops):
        position = i / (n_stops - 1)
        red, green, blue, _alpha = colormap(position)
        rgb = ",".join(str(int(round(float(c) * 255))) for c in (red, green, blue))
        scale.append([position, f"rgb({rgb})"])
    return scale


def plot_prithvi_output(out_tensor):
    """Render Streamlit controls and plots for one Prithvi output tensor.

    The user picks a variable, a plot type (contour or heatmap) and a
    colormap. The selected field is then drawn twice: as a static Matplotlib
    figure and as an interactive Plotly figure.

    Args:
        out_tensor: Model output. It is indexed as ``out_tensor[0, variable]``
            and converted with ``.detach().cpu().numpy()``, so a torch tensor
            is expected. It must be 4-D, presumably ``(batch, variable, lat,
            lon)`` — TODO confirm against the model wrapper.

    Returns:
        None. For missing or malformed input a Streamlit warning/error is
        shown and nothing is plotted.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    st.markdown("## 📊 Visualization Settings")

    # Indexing [0, variable] must leave a 2-D field. Any other rank, or an
    # empty variable axis, cannot be plotted below.
    if out_tensor.ndim != 4 or out_tensor.shape[1] == 0:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    variable_names = [f"Variable_{i}" for i in range(num_variables)]
    colormap_names = plt.colormaps()

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize.",
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)
    with col2:
        cmap = st.selectbox(
            "Select Color Map",
            options=colormap_names,
            index=colormap_names.index("viridis"),
        )
        if plot_type == "Contour":
            num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5)
        else:
            num_levels = None

    variable_index = variable_names.index(selected_variable_name)
    # detach() lets tensors that still track gradients be converted to NumPy.
    selected_variable = out_tensor[0, variable_index].detach().cpu().numpy()

    # NOTE(review): assumes a regular global grid with rows ordered south to
    # north and columns west to east — TODO confirm against the model output.
    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])

    # --- Static Matplotlib figure -------------------------------------------
    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == "Contour":
            X, Y = np.meshgrid(lon, lat)
            mappable = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
        else:
            mappable = ax.imshow(
                selected_variable,
                extent=[-180, 180, -90, 90],
                cmap=cmap,
                origin="lower",
                aspect="auto",
            )
        cbar = fig.colorbar(mappable, ax=ax)
        cbar.set_label(f"{selected_variable_name}", fontsize=12)
        ax.set_xlabel("Longitude", fontsize=12)
        ax.set_ylabel("Latitude", fontsize=12)
        ax.set_title(selected_variable_name, fontsize=14)
        st.pyplot(fig)
    finally:
        # Streamlit reruns this function on every interaction. Without closing,
        # pyplot keeps every figure alive.
        plt.close(fig)

    # --- Interactive Plotly figure ------------------------------------------
    st.markdown("#### Interactive Plot")
    colorscale = _mpl_cmap_to_plotly_colorscale(cmap)
    if plot_type == "Contour":
        trace = go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=colorscale,
            # ncontours is a trace-level attribute; it is not a valid key of
            # the `contours` dict, and Plotly raises if it is put there.
            ncontours=num_levels,
            contours=dict(
                coloring="fill",
                showlabels=True,
                labelfont=dict(size=12, color="white"),
            ),
        )
    else:
        trace = go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=colorscale)

    fig_plotly = go.Figure(data=trace)
    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)
def plot_aurora_output(preds, ds_subset):
    """Render Streamlit side-by-side maps of Aurora temperature predictions.

    For each prediction in ``preds``, the ``"t"`` atmospheric variable at a
    user-selected level is shifted by -273.15 (Kelvin to Celsius, presumably).
    It is plotted next to the matching time step of ``ds_subset.T``, shifted
    the same way.

    Args:
        preds: Sequence of prediction objects. Each one is read through
            ``pred.atmos_vars["t"]`` (indexed ``[0, 0, level]``; dimension 2
            is treated as the level axis), ``pred.metadata.time[0]``,
            ``pred.metadata.lat`` and ``pred.metadata.lon``. The tensor is
            assumed to be a torch tensor — TODO confirm.
        ds_subset: Ground-truth data. ``ds_subset.T.isel(lev=level)[idx]``
            must give the field matching the ``idx``-th prediction; presumably
            an xarray Dataset — verify against the caller.

    Returns:
        None. Problems are reported with ``st.error``. A prediction whose data
        cannot be extracted is skipped.
    """
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    # Narrow exceptions: an empty `preds`, a missing "t" key or attribute,
    # or a tensor with fewer than 3 dims all mean "no levels".
    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except (AttributeError, IndexError, KeyError, TypeError):
        levels = 0
    if levels < 1:
        st.error("Could not determine available levels in the data.")
        return

    if levels == 1:
        # st.slider raises when min_value == max_value; nothing to choose.
        selected_level = 0
    else:
        # Clamp the preferred default (11). st.slider raises if the default
        # lies outside [min_value, max_value].
        selected_level = st.slider("Select Level", 0, levels - 1, min(11, levels - 1), 1)

    for idx, pred in enumerate(preds):
        try:
            pred_time = pred.metadata.time[0]
            # detach()/cpu() so GPU-resident or grad-tracking tensors convert.
            pred_data = pred.atmos_vars["t"][0, 0, selected_level].detach().cpu().numpy() - 273.15
            truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
        except Exception as e:
            # UI boundary: report the problem and move on to the next step.
            st.error("Error extracting data for plotting:")
            st.error(e)
            continue

        lat = np.array(pred.metadata.lat)
        lon = np.array(pred.metadata.lon)
        extent = [lon.min(), lon.max(), lat.min(), lat.max()]
        # With origin="lower", imshow puts row 0 at the bottom edge (lat.min()).
        # That is only right when latitude increases with the row index.
        pred_origin = "lower" if lat[0] <= lat[-1] else "upper"
        # NOTE(review): the row order of truth_data is not visible here. It is
        # drawn with origin="lower" as before — confirm it runs south to north.
        panels = (
            (truth_data, "lower", "Ground Truth"),
            (pred_data, pred_origin, "Prediction"),
        )

        # Two panels, one per entry in `panels`. The original allocated three
        # and left the last one empty.
        fig, axes = plt.subplots(
            1, 2, figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()}
        )
        try:
            for ax, (data, origin, label) in zip(axes, panels):
                image = ax.imshow(
                    data,
                    extent=extent,
                    origin=origin,
                    cmap="coolwarm",
                    transform=ccrs.PlateCarree(),
                )
                ax.set_title(f"{label} at Level {selected_level} - {pred_time}")
                fig.colorbar(image, ax=ax, orientation="horizontal", pad=0.05)
            fig.tight_layout()
            st.pyplot(fig)
        finally:
            # One figure per prediction. Close each so reruns do not pile up
            # open pyplot figures.
            plt.close(fig)