File size: 3,574 Bytes
3ea658b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import io
import textwrap

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.manifold import TSNE
import plotly.express as px

# --------------  Helper functions -------------------------------------------
# Built-in demo point clouds, keyed by the label shown in the shape picker.
# Insertion order is the order the options appear in the UI.
EXAMPLE_SHAPES = {
    # All eight corners of the unit cube. np.ndindex walks them in C order:
    # (0,0,0), (0,0,1), (0,1,0), (0,1,1), (1,0,0), ..., (1,1,1).
    "Cube (3-D, 8 vertices)": np.array(list(np.ndindex(2, 2, 2))),
    # Four corners of a square in the z = 0 plane, then the apex above
    # the square's centre.
    "Square pyramid (3-D, 5 vertices)": np.array(
        [[-1, -1, 0], [1, -1, 0], [1, 1, 0], [-1, 1, 0], [0, 0, 1]]
    ),
}


def parse_text_points(text: str) -> np.ndarray:
    """
    Parse a multiline string of comma- or whitespace-separated numbers
    into an (n_points, n_dims) array.

    Each line holding at least one number is one point. Commas, spaces
    and tabs may be mixed freely as separators. Lines that contain no
    numbers (blank, or separators only) are skipped.

    Parameters
    ----------
    text : str
        Raw coordinates, one point per line.

    Returns
    -------
    np.ndarray
        Float array of shape ``(n_points, n_dims)``. If no line contains
        a number, an empty array of shape ``(0,)`` is returned.

    Raises
    ------
    ValueError
        If a token cannot be converted to ``float``, or if a line has a
        different number of coordinates than the first point. The message
        names the offending line (1-based, counted in ``text``).
    """
    rows = []
    n_dims = None
    for line_no, line in enumerate(text.splitlines(), start=1):
        tokens = line.replace(",", " ").split()
        if not tokens:
            continue
        try:
            point = [float(tok) for tok in tokens]
        except ValueError as err:
            raise ValueError(
                f"line {line_no}: cannot parse {line.strip()!r} as numbers"
            ) from err
        if n_dims is None:
            n_dims = len(point)
        elif len(point) != n_dims:
            # Catch ragged input here with a precise message; numpy would
            # otherwise fail with an opaque "inhomogeneous shape" error.
            raise ValueError(
                f"line {line_no}: expected {n_dims} coordinates, got {len(point)}"
            )
        rows.append(point)
    return np.array(rows, dtype=float)


def run_tsne(data: np.ndarray, perplexity: float, seed: int) -> np.ndarray:
    """
    Embed ``data`` into two dimensions with scikit-learn's t-SNE.

    Parameters
    ----------
    data : np.ndarray
        Points to embed, one row per point.
    perplexity : float
        Passed straight to ``TSNE``. scikit-learn requires it to be
        smaller than the number of rows.
    seed : int
        ``random_state`` for a reproducible layout.

    Returns
    -------
    np.ndarray
        The embedding, with one row per input point and 2 columns.
    """
    n_components = 2
    # PCA initialisation fits a PCA with ``n_components`` components, which
    # needs at least that many samples and features. For 1-D point clouds
    # that would raise, so fall back to random initialisation.
    init = "pca" if min(data.shape) >= n_components else "random"
    tsne = TSNE(
        n_components=n_components,
        perplexity=perplexity,
        random_state=seed,
        init=init,
    )
    return tsne.fit_transform(data)
# ---------------------------------------------------------------------------


# Page header. The title emoji is the cyclone U+1F300; the previous text
# showed it as mojibake ("πŸŒ€", its UTF-8 bytes decoded as a legacy codepage).
st.title("🌀 t-SNE Explorer for n-D Point Clouds")
st.markdown(
    """
    Upload or paste your points, choose parameters, and see how
    **t-SNE** flattens them into 2-D.  
    *Example shapes* are provided for quick experimentation.
    """
)

# --- Sidebar controls -------------------------------------------------------
with st.sidebar:
    st.header("1️⃣  Choose data source")
    source = st.radio(
        "Data input method",
        ["Example shape", "Upload CSV/TXT", "Paste raw text"]
    )

    if source == "Example shape":
        shape_key = st.selectbox("Pick a shape", list(EXAMPLE_SHAPES.keys()))
        data_raw = EXAMPLE_SHAPES[shape_key]

    elif source == "Upload CSV/TXT":
        file = st.file_uploader("Upload coordinates file (*.csv / *.txt)")
        # Nothing uploaded yet: end this script run until a file arrives.
        if file is None:
            st.stop()
        try:
            # "utf-8-sig" also strips a leading byte-order mark (common in
            # spreadsheet exports) that would make the first value
            # unparsable. UnicodeDecodeError subclasses ValueError, so
            # non-UTF-8 files are reported by the handler below too.
            data_raw = parse_text_points(file.getvalue().decode("utf-8-sig"))
        except ValueError as err:
            st.error(f"Could not read the uploaded file: {err}")
            st.stop()

    else:  # Paste text
        placeholder = "e.g.\n0,0,0\n0,0,1\n0,1,0\n..."
        text = st.text_area("Paste coordinates (one point per line)", height=200, placeholder=placeholder)
        if not text.strip():
            st.stop()
        try:
            data_raw = parse_text_points(text)
        except ValueError as err:
            st.error(f"Could not read the pasted points: {err}")
            st.stop()

    st.divider()
    st.header("2️⃣  t-SNE parameters")
    # scikit-learn's TSNE requires perplexity < number of points. A fixed
    # 5-50 range with default 30 made both built-in examples (8 and 5
    # points) fail out of the box, so cap the range by the data size.
    n_points = data_raw.shape[0] if data_raw.ndim == 2 else 0
    max_perplexity = float(min(50, n_points - 1))
    if max_perplexity > 1.0:
        perplexity = st.slider(
            "Perplexity", 1.0, max_perplexity, min(30.0, max_perplexity), 1.0
        )
    else:
        # Too few points for a slider range. 1.0 is valid for two points,
        # and smaller inputs are rejected by the main-area checks.
        perplexity = 1.0
    # numpy seeds must lie in [0, 2**32 - 1]; negative values would crash TSNE.
    seed = st.number_input(
        "Random seed", value=42, min_value=0, max_value=2**32 - 1, step=1
    )
    run_button = st.button("Run t-SNE 🚀")

# --- Main area --------------------------------------------------------------
if run_button:
    # Validate the input before handing it to scikit-learn.
    if data_raw.ndim != 2 or data_raw.shape[0] < 2:
        st.error("Need at least two points; check your input.")
        st.stop()

    if perplexity >= data_raw.shape[0]:
        st.error("Perplexity must be less than the number of points.")
        st.stop()

    try:
        embedding = run_tsne(data_raw, perplexity, seed)
    except ValueError as err:
        # scikit-learn raises ValueError for input it rejects, e.g. NaN/inf
        # coordinates (which float() happily parses). Report it in the page
        # instead of crashing with a traceback.
        st.error(f"t-SNE failed: {err}")
        st.stop()

    df_plot = pd.DataFrame(embedding, columns=["x", "y"])
    # Input row index of each embedded point, shown on hover so points can
    # be matched against the raw-data table below (same 0-based index).
    df_plot["point"] = np.arange(len(df_plot))

    st.subheader("2-D embedding")
    fig = px.scatter(
        df_plot, x="x", y="y", hover_data=["point"], width=700, height=500
    )
    fig.update_traces(marker=dict(size=10))
    fig.update_layout(margin=dict(l=20, r=20, t=30, b=20))
    st.plotly_chart(fig, use_container_width=True)

    with st.expander("Show raw data"):
        st.write(pd.DataFrame(data_raw))