File size: 5,367 Bytes
2d4811a
 
 
 
 
 
 
 
68a5c1a
 
 
 
2d4811a
68a5c1a
 
 
2d4811a
68a5c1a
 
 
2d4811a
68a5c1a
 
 
2d4811a
68a5c1a
2d4811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68a5c1a
2d4811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68a5c1a
2d4811a
68a5c1a
2d4811a
 
68a5c1a
2d4811a
68a5c1a
2d4811a
68a5c1a
2d4811a
 
 
 
68a5c1a
 
 
 
 
 
 
 
 
 
2d4811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68a5c1a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from src.subpages.page import Context, Page


class HiddenStatesVisualizer:
    """Render 2-D projections of per-token hidden states as Streamlit scatter plots.

    Reads the ``tokens``, ``ids``, ``hidden_states``, ``labels`` and ``preds`` columns of
    ``context.df_tokens_merged`` (one row per token) and works on a private copy, so the
    columns added while plotting never leak back into the shared context.
    """

    def __init__(self, context: "Context"):
        self.context = context
        # Copy so the x/y/sent*/disagreements columns added below stay local to this render.
        self.df = context.df_tokens_merged.copy()

    @staticmethod
    def _randomized_svd_projection(
        X, n_components: int, n_iter: int, random_state, n_oversamples: int = 10
    ) -> np.ndarray:
        """Project the rows of ``X`` onto its top right-singular vectors (randomized SVD).

        Implements the range-finder of Halko, Martinsson & Tropp (2011). A Gaussian test
        matrix is refined by ``n_iter`` power iterations, with a QR re-orthonormalisation
        after every multiplication for numerical stability.

        Args:
            X: 2-D array-like of shape (n_rows, n_cols); it is not centred here.
            n_components: number of output coordinates per row.
            n_iter: number of power iterations; more iterations give a sharper spectrum estimate.
            random_state: seed for ``np.random.default_rng`` (``None`` means non-deterministic).
            n_oversamples: extra random directions sampled beyond ``n_components``.

        Returns:
            Array of shape (n_rows, n_components) equal to ``U_k * s_k`` (i.e. ``X @ V_k``),
            up to the usual per-column sign ambiguity of the SVD. Columns that cannot exist
            because ``X`` has too few rows or columns are zero-filled.

        Raises:
            ValueError: if ``X`` is not two-dimensional.
        """
        X = np.asarray(X, dtype=np.float64)
        if X.ndim != 2:
            raise ValueError(f"expected a 2-D array, got shape {X.shape}")
        n_rows, n_cols = X.shape
        projected = np.zeros((n_rows, n_components))
        sketch_size = min(n_components + n_oversamples, n_rows, n_cols)
        if sketch_size == 0:
            return projected

        rng = np.random.default_rng(random_state)
        Q, _ = np.linalg.qr(X @ rng.standard_normal((n_cols, sketch_size)))
        for _ in range(max(int(n_iter), 0)):
            Q, _ = np.linalg.qr(X.T @ Q)
            Q, _ = np.linalg.qr(X @ Q)

        # X ≈ Q (Qᵀ X); an exact SVD of the small (sketch_size × n_cols) matrix gives U and s.
        U_small, s, _ = np.linalg.svd(Q.T @ X, full_matrices=False)
        k = min(n_components, s.shape[0])
        projected[:, :k] = (Q @ U_small[:, :k]) * s[:k]
        return projected

    def _reduce_dim_svd(self, X, n_iter: int, random_state=42):
        """Return a 2-D truncated-SVD embedding of ``X`` (no centring), shape (n_rows, 2)."""
        return self._randomized_svd_projection(
            X, n_components=2, n_iter=n_iter, random_state=random_state
        )

    def _reduce_dim_pca(self, X, random_state=42):
        """Return the first two principal-component scores of ``X``, shape (n_rows, 2).

        PCA is SVD of the column-centred data. Seven power iterations keep the randomized
        estimate close to exact for the leading components.
        """
        X = np.asarray(X, dtype=np.float64)
        centered = X - X.mean(axis=0, keepdims=True)
        return self._randomized_svd_projection(
            centered, n_components=2, n_iter=7, random_state=random_state
        )

    def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
        """UMAP embedding placeholder.

        Raises:
            NotImplementedError: always. No UMAP implementation is available to this module.
        """
        raise NotImplementedError(
            "UMAP projection is not implemented yet; please choose SVD or PCA."
        )

    def visualize_hidden_states(self):
        """Draw the settings column and two projection scatter plots.

        Every row of ``self.df`` (one token) has its ``hidden_states`` vector projected to
        2-D. It is plotted once coloured by ``labels`` and once coloured by ``preds``.
        Tokens whose label and prediction differ get a thin black outline from an overlay
        trace. Assumes every ``hidden_states`` entry is a 1-D vector of the same length
        (TODO confirm).
        """
        st.title("Embeddings")

        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
            )

        if self.df.empty:
            st.info("There are no tokens to project.")
            return

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
        # Bind every algorithm parameter up front so each branch below can rely on it.
        svd_n_iter = 1
        umap_n_neighbors, umap_min_dist, umap_metric = 5, 0.1, "euclidean"

        with col1:
            st.subheader("Settings")
            n_rows = len(self.df)
            # `subset` below slices rows, so the slider bound is the row count (not the number
            # of distinct token strings). st.slider needs min_value < max_value, so datasets
            # with at most 100 rows skip the slider and plot every row.
            if n_rows > 100:
                n_tokens = st.slider(
                    "#tokens",
                    key="n_tokens",
                    min_value=100,
                    max_value=n_rows,
                    step=100,
                )
            else:
                n_tokens = n_rows

            dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
            if dim_algo == "SVD":
                svd_n_iter = st.slider(
                    "#iterations",
                    key="svd_n_iter",
                    min_value=1,
                    max_value=10,
                    step=1,
                )
            elif dim_algo == "UMAP":
                umap_n_neighbors = st.slider(
                    "#neighbors",
                    key="umap_n_neighbors",
                    min_value=2,
                    max_value=100,
                    step=1,
                )
                umap_min_dist = st.number_input(
                    "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
                )
                umap_metric = st.selectbox(
                    "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
                )

        with col2:
            X = np.array(self.df["hidden_states"].tolist())
            try:
                if dim_algo == "SVD":
                    projected = self._reduce_dim_svd(X, n_iter=svd_n_iter)
                elif dim_algo == "PCA":
                    projected = self._reduce_dim_pca(X)
                else:
                    projected = self._reduce_dim_umap(
                        X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric
                    )
            except NotImplementedError as err:
                # Show a readable message instead of failing the whole page.
                st.error(str(err))
                return

            self.df["x"] = projected[:, 0]
            self.df["y"] = projected[:, 1]

            # Full sentence per id (its tokens joined by spaces), cut into five 50-character
            # hover lines; each piece has its internal whitespace collapsed.
            sents = self.df.groupby("ids")["tokens"].agg(" ".join)
            snippet_cols = [f"sent{i}" for i in range(5)]
            for i, col_name in enumerate(snippet_cols):
                start, stop = 50 * i, 50 * (i + 1)
                snippets = sents.map(lambda s, a=start, b=stop: " ".join(s[a:b].split()))
                self.df[col_name] = self.df["ids"].map(snippets)
            self.df["disagreements"] = self.df["labels"] != self.df["preds"]

            subset = self.df.iloc[:n_tokens]
            flagged = subset[subset["disagreements"]]
            # Transparent markers with a black outline, drawn over the coloured points.
            disagreements_trace = go.Scatter(
                x=flagged["x"],
                y=flagged["y"],
                mode="markers",
                marker=dict(
                    size=6,
                    color="rgba(0,0,0,0)",
                    line=dict(width=1, color="black"),
                ),
                hoverinfo="skip",
                showlegend=False,
            )

            st.subheader("Projection Results")

            for color_col, other_col, title in (
                ("labels", "preds", "Colored by label"),
                ("preds", "labels", "Colored by prediction"),
            ):
                fig = px.scatter(
                    subset,
                    x="x",
                    y="y",
                    color=color_col,
                    hover_data=["ids", other_col, *snippet_cols],
                    hover_name="tokens",
                    title=title,
                )
                fig.add_trace(disagreements_trace)
                st.plotly_chart(fig)


class HiddenStatesPage(Page):
    """Page entry that delegates all drawing to ``HiddenStatesVisualizer``."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Build a visualizer for ``context`` and draw its hidden-state plots."""
        HiddenStatesVisualizer(context).visualize_hidden_states()