Spaces:
Sleeping
Sleeping
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import streamlit as st | |
from src.subpages.page import Context, Page | |
class HiddenStatesVisualizer:
    """Streamlit view that projects per-token hidden states onto a 2-D plane.

    The visualizer works on a private copy of ``context.df_tokens_merged``.
    The code below reads the columns ``tokens``, ``ids``, ``hidden_states``,
    ``labels`` and ``preds`` from that frame and adds ``x``, ``y``,
    ``sent0`` .. ``sent4`` and ``disagreements`` to the copy.
    """

    # Every reducer returns this many columns (the scatter plots are 2-D).
    _N_COMPONENTS = 2
    # Extra random directions sampled by the randomized SVD sketch.
    _SVD_OVERSAMPLES = 10
    # Hover text shows the sentence as this many chunks of this many characters.
    _HOVER_CHUNKS = 5
    _HOVER_CHUNK_CHARS = 50

    def __init__(self, context: "Context"):
        """Store *context* and take a copy of its merged token frame."""
        self.context = context
        # Copy so the columns added while rendering never leak into the context.
        self.df = context.df_tokens_merged.copy()

    @staticmethod
    def _as_matrix(X) -> np.ndarray:
        """Return *X* as a 2-D float array; raise ``ValueError`` otherwise."""
        matrix = np.asarray(X, dtype=float)
        if matrix.ndim != 2:
            raise ValueError(
                f"expected a 2-D array of hidden states, got shape {matrix.shape}"
            )
        return matrix

    @classmethod
    def _finish_projection(cls, left, singular_values, right_t) -> np.ndarray:
        """Turn SVD factors into sign-stable coordinates with two columns."""
        k = cls._N_COMPONENTS
        left, singular_values, right_t = left[:, :k], singular_values[:k], right_t[:k]
        if right_t.size:
            # An SVD is only unique up to the sign of each component. Orient
            # every component so that its largest-magnitude loading is
            # positive, which keeps the plot from flipping between reruns.
            pivots = np.abs(right_t).argmax(axis=1)
            signs = np.sign(right_t[np.arange(right_t.shape[0]), pivots])
            signs[signs == 0] = 1.0
            left = left * signs
        coords = left * singular_values
        missing = k - coords.shape[1]
        if missing > 0:
            # Fewer than two components exist (e.g. a single sample or a
            # single feature): pad with zeros so callers can always read
            # columns 0 and 1.
            coords = np.hstack([coords, np.zeros((coords.shape[0], missing))])
        return coords

    def _reduce_dim_svd(self, X, n_iter: int, random_state=42):
        """Project *X* onto its top two singular directions (no centering).

        Uses a randomized range sketch refined by ``n_iter`` power iterations,
        followed by an exact SVD of the small sketched matrix.

        Args:
            X: 2-D array-like of shape ``(n_samples, n_features)``.
            n_iter: Number of power iterations; negative values count as 0.
            random_state: Seed passed to ``np.random.default_rng``.

        Returns:
            ``np.ndarray`` of shape ``(n_samples, 2)``.

        Raises:
            ValueError: If *X* is not 2-D.
        """
        matrix = self._as_matrix(X)
        n_samples, n_features = matrix.shape
        if matrix.size == 0:
            return np.zeros((n_samples, self._N_COMPONENTS))

        rng = np.random.default_rng(random_state)
        sketch_width = min(self._N_COMPONENTS + self._SVD_OVERSAMPLES, n_features)
        # Orthonormal basis approximating the column space of ``matrix``.
        basis, _ = np.linalg.qr(matrix @ rng.standard_normal((n_features, sketch_width)))
        for _ in range(max(int(n_iter), 0)):
            # Re-orthonormalize after every multiplication to stay stable.
            basis, _ = np.linalg.qr(matrix.T @ basis)
            basis, _ = np.linalg.qr(matrix @ basis)

        small_left, singular_values, right_t = np.linalg.svd(
            basis.T @ matrix, full_matrices=False
        )
        return self._finish_projection(basis @ small_left, singular_values, right_t)

    def _reduce_dim_pca(self, X, random_state=42):
        """Project *X* onto its first two principal components.

        The columns of *X* are mean-centered and an exact SVD is taken.

        Args:
            X: 2-D array-like of shape ``(n_samples, n_features)``.
            random_state: Accepted for interface compatibility; the exact
                solver used here involves no randomness, so it is ignored.

        Returns:
            ``np.ndarray`` of shape ``(n_samples, 2)``.

        Raises:
            ValueError: If *X* is not 2-D.
        """
        matrix = self._as_matrix(X)
        if matrix.size == 0:
            return np.zeros((matrix.shape[0], self._N_COMPONENTS))
        centered = matrix - matrix.mean(axis=0)
        left, singular_values, right_t = np.linalg.svd(centered, full_matrices=False)
        return self._finish_projection(left, singular_values, right_t)

    def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
        """Placeholder for a UMAP projection of *X*.

        UMAP needs a third-party implementation that this module does not
        import, so this method fails loudly instead of returning ``None``.
        Override it to return an ``np.ndarray`` of shape ``(n_samples, 2)``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "UMAP projection is not implemented; override _reduce_dim_umap "
            "or choose SVD or PCA instead."
        )

    def visualize_hidden_states(self):
        """Render the "Embeddings" page.

        The left column holds the settings widgets; the right column shows two
        scatter plots of the projected hidden states, one colored by label and
        one by prediction. Points whose label and prediction differ get an
        extra outlined marker.

        Raises:
            ValueError: If the selected algorithm has no reducer.
            TypeError: If the reducer does not return an ``np.ndarray``.
        """
        st.title("Embeddings")
        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
            )

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])

        # Algorithm name -> reducer. Insertion order is the select-box order.
        reducers = {
            "SVD": self._reduce_dim_svd,
            "PCA": self._reduce_dim_pca,
            "UMAP": self._reduce_dim_umap,
        }
        # Keyword arguments for the chosen reducer, filled from the widgets.
        reducer_kwargs = {}

        with col1:
            st.subheader("Settings")
            # NOTE(review): the upper bound counts unique tokens while the
            # subset below slices rows; kept as-is, confirm intent.
            n_tokens = st.slider(
                "#tokens",
                key="n_tokens",
                min_value=100,
                max_value=len(self.df["tokens"].unique()),
                step=100,
            )
            dim_algo = st.selectbox("Dimensionality reduction algorithm", list(reducers))
            if dim_algo == "SVD":
                reducer_kwargs["n_iter"] = st.slider(
                    "#iterations",
                    key="svd_n_iter",
                    min_value=1,
                    max_value=10,
                    step=1,
                )
            elif dim_algo == "UMAP":
                reducer_kwargs["n_neighbors"] = st.slider(
                    "#neighbors",
                    key="umap_n_neighbors",
                    min_value=2,
                    max_value=100,
                    step=1,
                )
                reducer_kwargs["min_dist"] = st.number_input(
                    "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
                )
                reducer_kwargs["metric"] = st.selectbox(
                    "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
                )

        with col2:
            # One space-joined string per sentence id, used for hover text.
            sents = self.df.groupby("ids")["tokens"].apply(
                lambda toks: " ".join(toks.tolist())
            )
            X = np.array(self.df["hidden_states"].tolist())

            reduce = reducers.get(dim_algo)
            if reduce is None:
                raise ValueError(f"unknown dimensionality reduction algorithm: {dim_algo!r}")
            try:
                projection = reduce(X, **reducer_kwargs)
            except NotImplementedError as exc:
                # Show a readable message instead of a traceback.
                st.error(str(exc))
                return
            # Explicit check rather than ``assert``, which is stripped under -O.
            if not isinstance(projection, np.ndarray):
                raise TypeError(
                    f"{dim_algo} reducer must return a numpy array, "
                    f"got {type(projection).__name__}"
                )

            self.df["x"] = projection[:, 0]
            self.df["y"] = projection[:, 1]

            # Split each sentence into fixed-width character chunks so the
            # hover box stays narrow; split/join normalizes chunk whitespace.
            sent_cols = [f"sent{i}" for i in range(self._HOVER_CHUNKS)]
            width = self._HOVER_CHUNK_CHARS
            for i, col in enumerate(sent_cols):
                start = i * width
                self.df[col] = self.df["ids"].map(
                    lambda sent_id, start=start: " ".join(
                        sents[sent_id][start:start + width].split()
                    )
                )

            self.df["disagreements"] = self.df["labels"] != self.df["preds"]
            subset = self.df[:n_tokens]

            # Transparent markers with an outline, drawn over mismatched points.
            mismatched = subset[subset["disagreements"]]
            disagreements_trace = go.Scatter(
                x=mismatched["x"],
                y=mismatched["y"],
                mode="markers",
                marker=dict(
                    size=6,
                    color="rgba(0,0,0,0)",
                    line=dict(width=1),
                ),
                hoverinfo="skip",
            )

            st.subheader("Projection Results")
            # (column used for color, column shown on hover, plot title)
            views = (
                ("labels", "preds", "Colored by label"),
                ("preds", "labels", "Colored by prediction"),
            )
            for color_col, hover_col, title in views:
                fig = px.scatter(
                    subset,
                    x="x",
                    y="y",
                    color=color_col,
                    hover_data=["ids", hover_col, *sent_cols],
                    hover_name="tokens",
                    title=title,
                )
                fig.add_trace(disagreements_trace)
                st.plotly_chart(fig)
class HiddenStatesPage(Page):
    """Page entry titled "Hidden States" that draws the hidden-state projection view."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Build a ``HiddenStatesVisualizer`` for *context* and render it."""
        HiddenStatesVisualizer(context).visualize_hidden_states()