# tsne / app.py
# Committed by euler314 as "Create app.py" (3ea658b, 3.57 kB).
import io
import textwrap
import numpy as np
import pandas as pd
import streamlit as st
from sklearn.manifold import TSNE
import plotly.express as px
# -------------- Helper functions -------------------------------------------
# Demo point clouds for the "Example shape" picker; dict order is the order
# in which the options appear in the selectbox.
EXAMPLE_SHAPES = {
    # Every 0/1 coordinate triple, listed in lexicographic order:
    # (0,0,0), (0,0,1), (0,1,0), ..., (1,1,1).
    "Cube (3-D, 8 vertices)": np.indices((2, 2, 2)).reshape(3, -1).T.copy(),
    # Four base corners of a 2x2 square in the z=0 plane (counter-clockwise
    # starting at (-1, -1)), followed by the apex above the origin.
    "Square pyramid (3-D, 5 vertices)": np.array(
        [(-1, -1, 0), (1, -1, 0), (1, 1, 0), (-1, 1, 0), (0, 0, 1)]
    ),
}
def parse_text_points(text: str) -> np.ndarray:
    """
    Parse a multiline string of comma- or whitespace-separated numbers
    into an (n_points, n_dims) array.

    Blank lines, lines containing only separators, and lines whose first
    non-space character is ``#`` are ignored. A leading UTF-8 byte-order
    mark (as written by some spreadsheet exports) is dropped.

    Parameters
    ----------
    text : str
        One point per line; coordinates separated by commas and/or whitespace.

    Returns
    -------
    np.ndarray
        Float array of shape (n_points, n_dims). When no data lines are
        present the result is an empty 1-D array of shape ``(0,)``.

    Raises
    ------
    ValueError
        If a token is not a number, or if data lines have differing numbers
        of coordinates. The message names the offending 1-based line.
    """
    rows = []
    n_dims = None
    # Line numbers refer to the text exactly as the user supplied it.
    for lineno, line in enumerate(text.lstrip("\ufeff").splitlines(), start=1):
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        tokens = stripped.replace(",", " ").split()
        if not tokens:  # e.g. a line of bare commas
            continue
        try:
            values = [float(tok) for tok in tokens]
        except ValueError as err:
            raise ValueError(
                f"Line {lineno}: could not parse {stripped!r} as numbers."
            ) from err
        if n_dims is None:
            n_dims = len(values)
        elif len(values) != n_dims:
            raise ValueError(
                f"Line {lineno}: expected {n_dims} coordinates, got {len(values)}."
            )
        rows.append(values)
    return np.array(rows, dtype=float)
def run_tsne(data: np.ndarray, perplexity: float, seed: int) -> np.ndarray:
    """
    Embed ``data`` into 2-D with scikit-learn's t-SNE.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (n_points, n_dims).
    perplexity : float
        t-SNE perplexity; scikit-learn requires it to be below n_points.
    seed : int
        Passed as ``random_state`` so repeated runs give the same layout.

    Returns
    -------
    np.ndarray
        Array of shape (n_points, 2) with the embedding coordinates.
    """
    data = np.asarray(data, dtype=float)
    # PCA initialisation projects onto 2 principal components, which needs at
    # least 2 input dimensions; fall back to random init for 1-D point clouds.
    init = "pca" if data.ndim == 2 and data.shape[1] >= 2 else "random"
    tsne = TSNE(
        n_components=2,
        perplexity=float(perplexity),
        random_state=int(seed),
        init=init,
    )
    return tsne.fit_transform(data)
# ---------------------------------------------------------------------------
# Page title and short usage blurb (title emoji repaired from mojibake).
st.title("🌀 t-SNE Explorer for n-D Point Clouds")
st.markdown(
    """
Upload or paste your points, choose parameters, and see how
**t-SNE** flattens them into 2-D.
*Example shapes* are provided for quick experimentation.
"""
)
# --- Sidebar controls -------------------------------------------------------
with st.sidebar:
    st.header("1️⃣ Choose data source")
    source = st.radio(
        "Data input method",
        ["Example shape", "Upload CSV/TXT", "Paste raw text"]
    )

    if source == "Example shape":
        shape_key = st.selectbox("Pick a shape", list(EXAMPLE_SHAPES))
        data_raw = EXAMPLE_SHAPES[shape_key]
    else:
        if source == "Upload CSV/TXT":
            file = st.file_uploader("Upload coordinates file (*.csv / *.txt)")
            if file is None:
                st.stop()
            try:
                # "utf-8-sig" also drops the byte-order mark that spreadsheet
                # exports often prepend, which float() would otherwise reject.
                text = file.getvalue().decode("utf-8-sig")
            except UnicodeDecodeError:
                st.error("Could not decode the file as UTF-8 text.")
                st.stop()
        else:  # Paste text
            placeholder = "e.g.\n0,0,0\n0,0,1\n0,1,0\n..."
            text = st.text_area(
                "Paste coordinates (one point per line)",
                height=200,
                placeholder=placeholder,
            )
            if not text.strip():
                st.stop()
        # Show malformed input as a readable message instead of a traceback.
        try:
            data_raw = parse_text_points(text)
        except ValueError as err:
            st.error(f"Could not parse the coordinates: {err}")
            st.stop()

    st.divider()
    st.header("2️⃣ t-SNE parameters")
    # scikit-learn requires perplexity < number of points, so the slider's
    # upper bound shrinks for small clouds (the 5-vertex pyramid allows at
    # most 4). Large inputs keep the usual 1–50 range with a default of 30.
    n_points = data_raw.shape[0] if data_raw.ndim == 2 else 0
    max_perplexity = float(min(50, max(n_points - 1, 2)))
    perplexity = st.slider(
        "Perplexity", 1.0, max_perplexity, min(30.0, max_perplexity), 1.0
    )
    # random_state is fed to numpy's RandomState, which only accepts
    # seeds in 0 .. 2**32 - 1.
    seed = st.number_input(
        "Random seed", min_value=0, max_value=2**32 - 1, value=42, step=1
    )
    run_button = st.button("Run t-SNE 🚀")
# --- Main area --------------------------------------------------------------
if run_button:
    # Validate before handing the data to t-SNE so problems surface as
    # readable messages rather than scikit-learn tracebacks.
    if data_raw.ndim != 2 or data_raw.shape[0] < 2:
        st.error("Need at least two points; check your input.")
        st.stop()
    # float() accepts "nan"/"inf", which t-SNE cannot handle.
    if not np.isfinite(data_raw).all():
        st.error("Coordinates must be finite numbers (no NaN or infinity).")
        st.stop()
    if perplexity >= data_raw.shape[0]:
        st.error("Perplexity must be less than the number of points.")
        st.stop()

    try:
        with st.spinner("Running t-SNE…"):
            embedding = run_tsne(data_raw, perplexity, seed)
    except ValueError as err:
        st.error(f"t-SNE failed: {err}")
        st.stop()

    df_plot = pd.DataFrame(embedding, columns=["x", "y"])
    # Row index of each input point, shown on hover so an embedded point can
    # be matched to its row in the raw-data table below.
    df_plot["point"] = np.arange(len(df_plot))
    st.subheader("2-D embedding")
    fig = px.scatter(
        df_plot, x="x", y="y", hover_data=["point"], width=700, height=500
    )
    fig.update_traces(marker=dict(size=10))
    fig.update_layout(margin=dict(l=20, r=20, t=30, b=20))
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("Show raw data"):
        st.write(pd.DataFrame(data_raw))