# tsne/app.py
import io
import textwrap
import itertools
import numpy as np
import pandas as pd
import streamlit as st
from sklearn.manifold import TSNE, trustworthiness
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans, DBSCAN
import umap.umap_ as umap
import plotly.express as px
from sklearn.datasets import make_swiss_roll
# --- Example shapes (some generated on demand) --------------------------------
def generate_hypercube(n=4):
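    """Return the 2**n vertices of the n-dimensional unit hypercube as a (2**n, n) float array."""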
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=float)
def generate_simplex(n=3):
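    """Return the n+1 vertices of the standard n-simplex in R^n: the n basis vectors plus the origin."""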
    eye = np.eye(n, dtype=float)
    origin = np.zeros((1, n), dtype=float)
    return np.vstack([eye, origin])
def generate_swiss_roll(n_samples=500, noise=0.05):
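    """Sample n_samples points from sklearn's noisy 3-D Swiss roll, discarding the roll parameter t."""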
    X, _ = make_swiss_roll(n_samples=n_samples, noise=noise)
    return X
EXAMPLE_SHAPES = {
"Cube (3-D, 8 pts)": np.array([
[0,0,0],[0,0,1],[0,1,0],[0,1,1],
[1,0,0],[1,0,1],[1,1,0],[1,1,1]
], dtype=float),
"Square pyramid (3-D, 5 pts)": np.array([
[-1,-1,0],[1,-1,0],[1,1,0],[-1,1,0],[0,0,1]
], dtype=float),
"4-D hypercube (16 pts)": generate_hypercube(4),
"3-simplex (4 pts in 3-D)": generate_simplex(3),
"Swiss roll (500 pts, 3-D)": generate_swiss_roll,
}
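# NOTE: the Swiss-roll entry is a callable rather than an array, so a fresh noisy
# sample is generated on demand (handled by the callable() check in the sidebar).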
# --- Parsing & embedding -----------------------------------------------------
def parse_text_points(text: str) -> np.ndarray:
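    """Parse one point per line, with comma- and/or whitespace-separated coordinates, into a float array."""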
    txt = textwrap.dedent(text.strip())
    rows = [r for r in txt.splitlines() if r.strip()]
    data = [list(map(float, r.replace(",", " ").split())) for r in rows]
    return np.array(data, dtype=float)
def run_tsne(data, perp, seed):
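    """Run 2-D t-SNE and return (embedding, final KL divergence)."""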
    # sklearn requires perplexity < n_samples, so clamp for tiny example shapes (e.g. the 8-point cube)
    perp = min(perp, data.shape[0] - 1)
    ts = TSNE(n_components=2, perplexity=perp, random_state=seed, init="pca")
    emb = ts.fit_transform(data)
    return emb, ts.kl_divergence_
def run_pca(data):
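    """Project to the first two principal components; None keeps the (embedding, metric) return shape."""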
    pca = PCA(n_components=2)
    return pca.fit_transform(data), None
def run_umap(data, n_neighbors, min_dist, seed):
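    """Embed to 2-D with UMAP; None keeps the (embedding, metric) return shape."""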
    um = umap.UMAP(n_components=2, n_neighbors=n_neighbors,
                   min_dist=min_dist, random_state=seed)
    return um.fit_transform(data), None
# --- Streamlit App -----------------------------------------------------------
st.set_page_config(layout="wide")
st.title("🌀 Dimensionality Reduction Explorer")
st.write("""
Upload or paste your n-D points, pick an algorithm (t-SNE/PCA/UMAP),
optionally cluster, and see the 2-D embedding.
""")
# Sidebar ─────────────────────────────────────────────────────────────────────
with st.sidebar:
    st.header("1️⃣ Data Input")
    mode = st.radio("Source", ["Example shape","Upload CSV/TXT","Paste text"])
    if mode == "Example shape":
        key = st.selectbox("Choose example", list(EXAMPLE_SHAPES.keys()))
        src = EXAMPLE_SHAPES[key]
        data_raw = src() if callable(src) else src
    elif mode == "Upload CSV/TXT":
        up = st.file_uploader("Upload file", type=["csv","txt"])
        if up:
            txt = io.StringIO(up.getvalue().decode("utf-8")).read()
            data_raw = parse_text_points(txt)
        else:
            st.stop()
    else:
        placeholder = "e.g.\n0,0,0\n0,0,1\n0,1,0\n..."
        txt = st.text_area("Paste coordinates", height=200, placeholder=placeholder)
        if not txt.strip():
            st.stop()
        data_raw = parse_text_points(txt)
    st.header("2️⃣ Algorithm & Params")
    algo = st.selectbox("Method", ["t-SNE","PCA","UMAP"])
    seed = st.number_input("Random seed", value=42, step=1)
    if algo == "t-SNE":
        perp = st.slider("Perplexity", 5.0, 50.0, 30.0, 1.0)
    elif algo == "UMAP":
        neighbors = st.slider("n_neighbors", 5, 200, 15, 5)
        min_dist = st.slider("min_dist", 0.0, 0.99, 0.1, 0.01)
    st.header("3️⃣ Clustering (optional)")
    do_cluster = st.checkbox("Cluster embedding")
    if do_cluster:
        cluster_algo = st.selectbox("Algorithm", ["KMeans","DBSCAN"])
        if cluster_algo == "KMeans":
            n_clusters = st.slider("n_clusters", 2, 10, 3, 1)
        else:
            eps = st.slider("DBSCAN eps", 0.1, 5.0, 0.5, 0.1)
    st.markdown("---")
    run = st.button("Run & Visualize 🚀")
# Main ────────────────────────────────────────────────────────────────────────
if run:
    pts = data_raw
    if pts.ndim != 2 or pts.shape[0] < 2 or pts.shape[1] < 2:
        st.error("Need an (n_pts × n_dims) array with at least two points and at least two dimensions.")
        st.stop()
    # run chosen reducer
    if algo == "t-SNE":
        emb, kl = run_tsne(pts, perp, seed)
    elif algo == "PCA":
        emb, kl = run_pca(pts)
    else:
        emb, kl = run_umap(pts, neighbors, min_dist, seed)
    # trustworthiness requires n_neighbors < n_samples / 2, so use the largest admissible k
    n_samples = pts.shape[0]
    k_max = (n_samples - 1) // 2
    if k_max >= 1:
        tw = trustworthiness(pts, emb, n_neighbors=k_max)
    else:
        tw = None
# clustering & plotting
df = pd.DataFrame(emb, columns=["x","y"])
if do_cluster:
if cluster_algo == "KMeans":
labels = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(emb)
else:
labels = DBSCAN(eps=eps).fit_predict(emb)
df["cluster"] = labels.astype(str)
fig = px.scatter(df, x="x", y="y", color="cluster",
title=f"{algo} embedding with {cluster_algo}", width=700, height=500)
else:
fig = px.scatter(df, x="x", y="y",
title=f"{algo} embedding", width=700, height=500)
fig.update_traces(marker=dict(size=8))
fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    # display
    st.subheader("2-D Embedding")
    st.plotly_chart(fig, use_container_width=True)
    if tw is not None:
        st.markdown(f"**Trustworthiness (k={k_max}):** {tw:.3f}")
    else:
        st.markdown("**Trustworthiness:** Not enough samples to compute (need ≥3 points).")
    if kl is not None:
        st.markdown(f"**t-SNE KL divergence:** {kl:.3f}")
    # download CSV
    csv = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        "Download embedding as CSV",
        data=csv,
        file_name="embedding.csv",
        mime="text/csv"
    )
    with st.expander("Show original data"):
        st.write(pts)
if algo == "t-SNE":
with st.expander("🧠 How t-SNE works"):
st.markdown(r"""
1. **High-D similarities**
   Convert pairwise distances $d_{ij}$ into conditional probabilities
   $$
   p_{j|i} = \frac{\exp\!\bigl(-\|x_i - x_j\|^2 / 2\sigma_i^2\bigr)}
                  {\sum_{k\neq i}\exp\!\bigl(-\|x_i - x_k\|^2 / 2\sigma_i^2\bigr)}
   $$
   then symmetrize to $p_{ij}=(p_{j|i}+p_{i|j})/2n$.
2. **Low-D affinities**
   In 2-D we use a Student-t kernel:
   $$
   q_{ij} = \frac{\bigl(1 + \|y_i - y_j\|^2\bigr)^{-1}}
                 {\sum_{k\neq l}\bigl(1 + \|y_k - y_l\|^2\bigr)^{-1}}
   $$
3. **Minimize KL divergence**
   Find $\{y_i\}$ to minimize
   $$
   KL(P\|Q) = \sum_{i\neq j} p_{ij}\,\log\frac{p_{ij}}{q_{ij}}
   $$
   via gradient descent, preserving local structure while pushing dissimilar points apart.

**Key parameter – perplexity**
Controls each $\sigma_i$ by solving
$\mathrm{Perp}(p_{i\cdot}) = 2^{-\sum_j p_{j|i}\log_2 p_{j|i}}$,
intuitively setting an "effective # neighbors" (5–50 typical).
""")