CodeMonkeyXL / src /subpages /hidden_states.py
K00B404's picture
Update src/subpages/hidden_states.py
68a5c1a verified
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from src.subpages.page import Context, Page
class HiddenStatesVisualizer:
    """Project per-token hidden states onto a 2-D plane and plot them.

    Works on a copy of ``context.df_tokens_merged``. The columns read are
    ``tokens``, ``ids``, ``labels``, ``preds`` and ``hidden_states`` (one vector
    per row). The plotting columns (``x``, ``y``, ``sent0``..``sent4``,
    ``disagreements``) are added to the copy only.
    """

    def __init__(self, context: "Context"):
        self.context = context
        # Work on a copy: the plotting columns added later must not leak into the
        # shared context dataframe.
        self.df = context.df_tokens_merged.copy()

    @staticmethod
    def _as_matrix(X) -> np.ndarray:
        """Return ``X`` as a float64 2-D array; raise ``ValueError`` for any other rank."""
        matrix = np.asarray(X, dtype=float)
        if matrix.ndim != 2:
            raise ValueError(f"expected a 2-D array of hidden states, got shape {matrix.shape}")
        return matrix

    def _reduce_dim_svd(self, X, n_iter: int, random_state=42, n_components: int = 2) -> np.ndarray:
        """Project ``X`` onto its leading singular directions (randomized truncated SVD).

        As with a truncated SVD, ``X`` is *not* mean-centered. The column space of
        ``X`` is sketched with a seeded Gaussian test matrix and refined by
        ``n_iter`` power iterations. The basis is re-orthonormalized with QR at
        every step so that smaller singular directions are not lost to round-off.
        More iterations give a more accurate projection, at the cost of two extra
        matrix products each.

        Args:
            X: Array-like of shape (n_samples, n_features).
            n_iter: Number of power iterations; negative values are treated as 0.
            random_state: Seed for the test matrix, so reruns give the same plot.
            n_components: Number of output dimensions.

        Returns:
            Array of shape (n_samples, k), k = min(n_components, n_samples,
            n_features), holding the scores ``U_k * s_k`` (equivalently ``X @ V_k``).

        Raises:
            ValueError: if ``X`` is not two-dimensional.
        """
        matrix = self._as_matrix(X)
        n_samples, n_features = matrix.shape
        rank = min(n_components, n_samples, n_features)
        if rank <= 0:
            return np.zeros((n_samples, 0))
        # A few extra sketch columns make the leading directions much more reliable.
        sketch_size = min(rank + 10, n_samples, n_features)
        rng = np.random.default_rng(random_state)
        basis, _ = np.linalg.qr(matrix @ rng.standard_normal((n_features, sketch_size)))
        for _ in range(max(n_iter, 0)):
            basis, _ = np.linalg.qr(matrix.T @ basis)
            basis, _ = np.linalg.qr(matrix @ basis)
        # Exact SVD of the small (sketch_size x n_features) projection of X.
        small_u, singular_values, _ = np.linalg.svd(basis.T @ matrix, full_matrices=False)
        return (basis @ small_u[:, :rank]) * singular_values[:rank]

    def _reduce_dim_pca(self, X, random_state=42, n_components: int = 2) -> np.ndarray:
        """Project mean-centered ``X`` onto its leading principal components.

        The principal axes come from an eigen-decomposition of the
        (n_features x n_features) scatter matrix of the centered data. This is
        exact and deterministic. ``random_state`` is accepted for interface
        compatibility only and is not used.

        Args:
            X: Array-like of shape (n_samples, n_features).
            random_state: Unused.
            n_components: Number of output dimensions.

        Returns:
            Array of shape (n_samples, k), k = min(n_components, n_samples,
            n_features), holding the centered data expressed in the principal axes.

        Raises:
            ValueError: if ``X`` is not two-dimensional.
        """
        matrix = self._as_matrix(X)
        n_samples, n_features = matrix.shape
        rank = min(n_components, n_samples, n_features)
        if rank <= 0:
            return np.zeros((n_samples, 0))
        centered = matrix - matrix.mean(axis=0)
        _, eigenvectors = np.linalg.eigh(centered.T @ centered)
        # eigh returns eigenvalues in ascending order, so the leading axes are the last columns.
        axes = eigenvectors[:, ::-1][:, :rank]
        return centered @ axes

    def _reduce_dim_umap(self, X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
        """UMAP projection; not available in this module.

        No UMAP implementation is imported here. This method therefore raises
        instead of silently returning ``None``, which previously surfaced as an
        opaque assertion failure.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "UMAP projection is not available in this build; please choose SVD or PCA."
        )

    def visualize_hidden_states(self):
        """Render the embeddings page.

        The left column holds the settings: number of plotted tokens, reduction
        algorithm, and that algorithm's parameters. The right column projects
        *all* rows, then plots the first ``#tokens`` rows twice: once colored by
        label and once colored by prediction.
        """
        st.title("Embeddings")
        with st.expander("💡", expanded=True):
            st.write(
                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with disagreements signified by a small black border."
            )
        if self.df.empty:
            st.info("There are no tokens to visualize.")
            return

        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
        with col1:
            n_tokens, dim_algo, reducer_kwargs = self._render_settings()
        with col2:
            reducers = {
                "SVD": self._reduce_dim_svd,
                "PCA": self._reduce_dim_pca,
                "UMAP": self._reduce_dim_umap,
            }
            hidden_states = np.array(self.df["hidden_states"].tolist())
            try:
                projected = reducers[dim_algo](hidden_states, **reducer_kwargs)
            except NotImplementedError as err:
                st.error(str(err))
                return
            if projected.ndim != 2 or projected.shape[1] < 2:
                st.error(
                    "A 2-D projection needs at least two tokens and two hidden-state dimensions."
                )
                return
            self._add_plot_columns(projected)
            self._render_plots(self.df.head(n_tokens))

    def _render_settings(self):
        """Draw the settings widgets; return ``(n_tokens, dim_algo, reducer_kwargs)``."""
        st.subheader("Settings")
        n_rows = len(self.df)
        # The plots show the first ``n_tokens`` rows, so the slider is bounded by
        # the row count. It is only offered when there is an actual range to pick
        # from; smaller datasets are plotted in full.
        if n_rows > 100:
            n_tokens = st.slider(
                "#tokens",
                key="n_tokens",
                min_value=100,
                max_value=n_rows,
                step=100,
            )
        else:
            n_tokens = n_rows

        dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
        # Keyword arguments forwarded to the selected ``_reduce_dim_*`` method.
        reducer_kwargs = {}
        if dim_algo == "SVD":
            reducer_kwargs["n_iter"] = st.slider(
                "#iterations",
                key="svd_n_iter",
                min_value=1,
                max_value=10,
                step=1,
            )
        elif dim_algo == "UMAP":
            reducer_kwargs["n_neighbors"] = st.slider(
                "#neighbors",
                key="umap_n_neighbors",
                min_value=2,
                max_value=100,
                step=1,
            )
            reducer_kwargs["min_dist"] = st.number_input(
                "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
            )
            reducer_kwargs["metric"] = st.selectbox(
                "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
            )
        return n_tokens, dim_algo, reducer_kwargs

    def _add_plot_columns(self, projected):
        """Add coordinate, sentence-preview and disagreement columns to ``self.df``."""
        self.df["x"] = projected[:, 0]
        self.df["y"] = projected[:, 1]
        # Full text per id: that id's tokens joined by single spaces.
        sentences = self.df.groupby("ids")["tokens"].agg(" ".join)
        sentence_per_row = self.df["ids"].map(sentences)
        # Hover text is split into five 50-character chunks ("sent0".."sent4"),
        # each with whitespace runs collapsed; text past 250 characters is dropped.
        for i in range(5):
            chunk = sentence_per_row.str[50 * i : 50 * (i + 1)]
            self.df[f"sent{i}"] = chunk.map(lambda text: " ".join(text.split()))
        self.df["disagreements"] = self.df["labels"] != self.df["preds"]

    def _render_plots(self, subset):
        """Draw the label-colored and prediction-colored scatter plots for ``subset``."""
        disagreeing = subset[subset["disagreements"]]
        # Transparent markers with a black outline, overlaid on the colored points
        # to flag tokens whose label and prediction differ.
        disagreements_trace = go.Scatter(
            x=disagreeing["x"],
            y=disagreeing["y"],
            mode="markers",
            marker=dict(
                size=6,
                color="rgba(0,0,0,0)",
                line=dict(width=1, color="black"),
            ),
            hoverinfo="skip",
        )
        sentence_cols = ["sent0", "sent1", "sent2", "sent3", "sent4"]
        st.subheader("Projection Results")
        for color_col, other_col, title in (
            ("labels", "preds", "Colored by label"),
            ("preds", "labels", "Colored by prediction"),
        ):
            fig = px.scatter(
                subset,
                x="x",
                y="y",
                color=color_col,
                hover_data=["ids", other_col, *sentence_cols],
                hover_name="tokens",
                title=title,
            )
            fig.add_trace(disagreements_trace)
            st.plotly_chart(fig)
class HiddenStatesPage(Page):
    """Page that shows 2-D projections of the per-token hidden states."""

    name = "Hidden States"
    icon = "grid-3x3"

    def render(self, context: Context):
        """Draw this page by delegating to ``HiddenStatesVisualizer`` for ``context``."""
        HiddenStatesVisualizer(context).visualize_hidden_states()