File size: 4,623 Bytes
100edb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import cartopy.crs as ccrs

def _mpl_cmap_to_plotly(cmap_name, n_samples=256):
    """Convert a Matplotlib colormap name into an explicit Plotly colorscale.

    Plotly only recognises its own set of named colorscales, so many names
    returned by ``plt.colormaps()`` (e.g. ``"coolwarm"``) would be rejected if
    passed straight through. Sampling the Matplotlib colormap makes every
    selectable option valid and keeps the static and interactive plots
    visually consistent.

    Args:
        cmap_name: Name of a registered Matplotlib colormap.
        n_samples: Number of evenly spaced stops in the returned colorscale.

    Returns:
        A list of ``[position, "rgb(r,g,b)"]`` pairs whose positions run from
        0.0 to 1.0 inclusive, as Plotly requires.
    """
    colormap = plt.get_cmap(cmap_name)
    scale = []
    for pos in np.linspace(0.0, 1.0, n_samples):
        r, g, b, _alpha = colormap(pos)
        scale.append([float(pos), f"rgb({round(r * 255)},{round(g * 255)},{round(b * 255)})"])
    return scale


def plot_prithvi_output(out_tensor):
    """Render Streamlit controls plus static and interactive plots of a Prithvi output.

    The user picks one channel along axis 1 of ``out_tensor``, a plot type
    (filled contour or heatmap) and a Matplotlib colormap. That channel of the
    first batch element is then drawn twice: once with Matplotlib and once
    with Plotly.

    Args:
        out_tensor: Model output with at least 4 dimensions, indexed as
            ``out_tensor[0, variable_index]`` to obtain a 2-D field. It must
            support ``.ndim``, ``.shape`` and ``.cpu().numpy()`` (presumably a
            ``torch.Tensor`` — TODO confirm). ``None`` shows a warning.

    Returns:
        None. All output is written to the Streamlit page.
    """
    if out_tensor is None:
        st.warning("No output available for plotting.")
        return

    # Example visualization UI for Prithvi:
    st.markdown("## 📊 Visualization Settings")
    if out_tensor.ndim < 4:
        st.error("The output tensor does not have the expected dimensions.")
        return

    num_variables = out_tensor.shape[1]
    if num_variables == 0:
        # An empty selectbox would return None and break the lookup below.
        st.error("The output tensor contains no variables to plot.")
        return
    variable_names = [f"Variable_{i}" for i in range(num_variables)]
    colormap_names = plt.colormaps()

    col1, col2 = st.columns(2)
    with col1:
        selected_variable_name = st.selectbox(
            "Select Variable to Plot",
            options=variable_names,
            index=0,
            help="Choose the variable to visualize."
        )
        plot_type = st.selectbox("Select Plot Type", ["Contour", "Heatmap"], index=0)

    with col2:
        cmap = st.selectbox("Select Color Map", options=colormap_names, index=colormap_names.index("viridis"))
        num_levels = st.slider("Number of Contour Levels", 5, 100, 20, 5) if plot_type == "Contour" else None

    variable_index = variable_names.index(selected_variable_name)
    selected_variable = out_tensor[0, variable_index].cpu().numpy()
    if selected_variable.ndim != 2:
        # Tensors with more than 4 dims pass the check above but give a
        # non-2-D slice that contourf/imshow/Plotly cannot draw as a map.
        st.error("The output tensor does not have the expected dimensions.")
        return

    # NOTE(review): the grid is synthesised as an evenly spaced global
    # lat/lon mesh with rows ascending in latitude; confirm this matches the
    # model's actual output grid.
    lat = np.linspace(-90, 90, selected_variable.shape[0])
    lon = np.linspace(-180, 180, selected_variable.shape[1])
    X, Y = np.meshgrid(lon, lat)

    st.markdown(f"### Plot of {selected_variable_name}")
    fig, ax = plt.subplots(figsize=(10, 6))
    if plot_type == "Contour":
        mappable = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
    else:
        mappable = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')

    cbar = plt.colorbar(mappable, ax=ax)
    cbar.set_label(f'{selected_variable_name}', fontsize=12)
    ax.set_xlabel("Longitude", fontsize=12)
    ax.set_ylabel("Latitude", fontsize=12)
    ax.set_title(selected_variable_name, fontsize=14)
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure so
    # pyplot does not accumulate open figures.
    plt.close(fig)

    # Plotly interactive plot
    st.markdown("#### Interactive Plot")
    plotly_colorscale = _mpl_cmap_to_plotly(cmap)
    if plot_type == "Contour":
        trace = go.Contour(
            z=selected_variable,
            x=lon,
            y=lat,
            colorscale=plotly_colorscale,
            # `ncontours` is a top-level Contour property, not a key of `contours`.
            ncontours=num_levels,
            contours=dict(coloring='fill', showlabels=True, labelfont=dict(size=12, color='white')),
        )
    else:
        trace = go.Heatmap(z=selected_variable, x=lon, y=lat, colorscale=plotly_colorscale)
    fig_plotly = go.Figure(data=trace)

    fig_plotly.update_layout(
        xaxis_title="Longitude",
        yaxis_title="Latitude",
        width=800,
        height=600,
    )
    st.plotly_chart(fig_plotly)


def plot_aurora_output(preds, ds_subset):
    if preds is None or ds_subset is None:
        st.error("No predictions or dataset subset available for visualization.")
        return

    try:
        levels = preds[0].atmos_vars["t"].shape[2]
    except:
        st.error("Could not determine available levels in the data.")
        return

    selected_level = st.slider('Select Level', 0, levels - 1, 11, 1)

    for idx, pred in enumerate(preds):
        pred_time = pred.metadata.time[0]

        try:
            pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
            truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
        except Exception as e:
            st.error("Error extracting data for plotting:")
            st.error(e)
            continue

        lat = np.array(pred.metadata.lat)
        lon = np.array(pred.metadata.lon)

        fig, axes = plt.subplots(1, 3, figsize=(18, 6), subplot_kw={'projection': ccrs.PlateCarree()})

        # Ground Truth
        im1 = axes[0].imshow(
            truth_data, extent=[lon.min(), lon.max(), lat.min(), lat.max()],
            origin='lower', cmap='coolwarm', transform=ccrs.PlateCarree()
        )
        axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
        plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)

        # Prediction
        im2 = axes[1].imshow(
            pred_data, extent=[lon.min(), lon.max(), lat.min(), lat.max()],
            origin='lower', cmap='coolwarm', transform=ccrs.PlateCarree()
        )
        axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
        plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)

        plt.tight_layout()
        st.pyplot(fig)