Update app.py
app.py
CHANGED
@@ -1,6 +1,8 @@
 import io
 import textwrap
 import itertools
+import zipfile
+from typing import List, Tuple
 
 import numpy as np
 import pandas as pd
@@ -10,9 +12,10 @@ from sklearn.decomposition import PCA
 from sklearn.cluster import KMeans, DBSCAN
 import umap.umap_ as umap
 import plotly.express as px
+from scipy.spatial.distance import cdist
 from sklearn.datasets import make_swiss_roll
 
+# ── Example shapes (some generated on demand) ────────────────────────────────
 def generate_hypercube(n=4):
     return np.array(list(itertools.product([0, 1], repeat=n)), dtype=float)
 
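`generate_hypercube` enumerates every 0/1 corner with `itertools.product`, so the point count doubles with each added dimension. A quick sanity check of the helper exactly as committed:

# Sanity check of generate_hypercube: n=2 gives the four corners of the unit square.
import itertools
import numpy as np

def generate_hypercube(n=4):
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=float)

print(generate_hypercube(2))
# [[0. 0.]
#  [0. 1.]
#  [1. 0.]
#  [1. 1.]]
print(generate_hypercube(4).shape)  # (16, 4) -- matches "4-D hypercube (16 pts)"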
@@ -26,19 +29,19 @@ def generate_swiss_roll(n_samples=500, noise=0.05):
     return X
 
 EXAMPLE_SHAPES = {
+    "Cube (3-D, 8 pts)": np.array([
         [0,0,0],[0,0,1],[0,1,0],[0,1,1],
         [1,0,0],[1,0,1],[1,1,0],[1,1,1]
     ], dtype=float),
+    "Square pyramid (3-D, 5 pts)": np.array([
         [-1,-1,0],[1,-1,0],[1,1,0],[-1,1,0],[0,0,1]
     ], dtype=float),
+    "4-D hypercube (16 pts)": generate_hypercube(4),
+    "3-simplex (4 pts in 3-D)": generate_simplex(3),
+    "Swiss roll (500 pts, 3-D)": generate_swiss_roll,
 }
 
+# ── Helpers ──────────────────────────────────────────────────────────────────
 def parse_text_points(text: str) -> np.ndarray:
     txt = textwrap.dedent(text.strip())
     rows = [r for r in txt.splitlines() if r.strip()]
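`EXAMPLE_SHAPES` mixes ready-made arrays with a bare function (the Swiss roll entry), so a generated shape is only built when it is actually selected; the sidebar resolves both cases with `src() if callable(src) else src`. A minimal sketch of that dispatch, where `make_points` is a stand-in generator rather than part of the app:

# Sketch of the callable-vs-array dispatch used for EXAMPLE_SHAPES values.
import numpy as np

def make_points():
    # Stand-in for generate_swiss_roll: builds its points only when called.
    return np.zeros((500, 3))

shapes = {
    "static": np.eye(3),  # stored ready-made, like the cube
    "lazy": make_points,  # stored as a function, like the Swiss roll
}

for key, src in shapes.items():
    data = src() if callable(src) else src  # same dispatch line as the sidebar
    print(key, data.shape)                  # static (3, 3) / lazy (500, 3)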
@@ -59,41 +62,54 @@ def run_umap(data, n_neighbors, min_dist, seed):
                    min_dist=min_dist, random_state=seed)
     return um.fit_transform(data), None
 
+# ── Streamlit UI ─────────────────────────────────────────────────────────────
 st.set_page_config(layout="wide")
 st.title("Dimensionality Reduction Explorer")
+
 st.write("""
+Upload **one or many** CSV/TXT files *or* use the other sources, pick an algorithm,
+(optionally cluster), and explore the 2-D embedding. Each result is downloadable
+with a full pair-wise distance table.
 """)
 
+# Sidebar ─────────────────────────────────────────────────────────────────────
 with st.sidebar:
     st.header("1️⃣ Data Input")
-    mode = st.radio("Source", ["Example shape","Upload CSV/TXT","Paste text"])
+    mode = st.radio("Source", ["Example shape", "Upload CSV/TXT", "Paste text"])
+
+    datasets: List[Tuple[str, np.ndarray]] = []
+
     if mode == "Example shape":
         key = st.selectbox("Choose example", list(EXAMPLE_SHAPES.keys()))
         src = EXAMPLE_SHAPES[key]
         data_raw = src() if callable(src) else src
+        datasets.append((key.replace(" ", "_"), data_raw))
+
     elif mode == "Upload CSV/TXT":
+        uploads = st.file_uploader(
+            "Upload one **or many** files",
+            type=["csv", "txt"],
+            accept_multiple_files=True
+        )
+        if not uploads:
             st.stop()
+        for up in uploads:
+            txt = io.StringIO(up.getvalue().decode("utf-8")).read()
+            pts = parse_text_points(txt)
+            datasets.append((up.name.rsplit(".", 1)[0], pts))
+    else:  # Paste text
         placeholder = "e.g.\n0,0,0\n0,0,1\n0,1,0\n..."
         txt = st.text_area("Paste coordinates", height=200, placeholder=placeholder)
         if not txt.strip():
             st.stop()
         data_raw = parse_text_points(txt)
+        datasets.append(("pasted_points", data_raw))
 
     st.header("2️⃣ Algorithm & Params")
+    algo = st.selectbox("Method", ["t-SNE", "PCA", "UMAP"])
     seed = st.number_input("Random seed", value=42, step=1)
 
+    if algo == "t-SNE":
         perp = st.slider("Perplexity", 1.0, 50.0, 30.0, 1.0)
     elif algo == "UMAP":
         neighbors = st.slider("n_neighbors", 5, 200, 15, 5)
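In the upload branch, `io.StringIO(up.getvalue().decode("utf-8")).read()` decodes the uploaded bytes and then round-trips them through `StringIO`, which returns the very string `decode` already produced; the parse path can therefore be exercised without Streamlit at all. A standalone check with made-up file bytes:

# Standalone check of the upload parse path; `raw` imitates up.getvalue().
import io

raw = b"0,0,0\n0,0,1\n0,1,0\n"

via_stringio = io.StringIO(raw.decode("utf-8")).read()  # what the commit does
direct = raw.decode("utf-8")                            # equivalent result
assert via_stringio == direct

# Either string is then handed to parse_text_points(txt).
print(direct.splitlines())  # ['0,0,0', '0,0,1', '0,1,0']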
@@ -102,7 +118,7 @@ with st.sidebar:
     st.header("3️⃣ Clustering (optional)")
     do_cluster = st.checkbox("Cluster embedding")
     if do_cluster:
-        cluster_algo = st.selectbox("Algorithm", ["KMeans","DBSCAN"])
+        cluster_algo = st.selectbox("Algorithm", ["KMeans", "DBSCAN"])
         if cluster_algo == "KMeans":
             n_clusters = st.slider("n_clusters", 2, 10, 3, 1)
         else:
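One behavior worth knowing before clustering an embedding: DBSCAN labels outliers -1, and since the main loop below stringifies labels for coloring, such points get their own "-1" legend entry. A toy check with illustrative coordinates (the app itself calls `DBSCAN(eps=eps)` with the default `min_samples=5`; `min_samples=2` here just keeps the example tiny):

# Toy check: DBSCAN marks outliers as -1; astype(str) later makes that a "-1" color group.
import numpy as np
from sklearn.cluster import DBSCAN, KMeans

emb = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0]])

print(DBSCAN(eps=0.5, min_samples=2).fit_predict(emb))  # [ 0  0  0 -1]
labels = KMeans(n_clusters=2, random_state=42).fit_predict(emb)
print(labels)  # groups the three near-origin points together; numbering may vary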
@@ -111,98 +127,95 @@ with st.sidebar:
     st.markdown("---")
     run = st.button("Run & Visualize")
 
+# ── Main processing ──────────────────────────────────────────────────────────
+def process_dataset(name: str, pts: np.ndarray):
-    pts = data_raw
     if pts.ndim != 2 or pts.shape[0] < 2:
+        st.error(f"Dataset **{name}** needs at least two points in an (n_pts × n_dims) array.")
+        return None, None
 
+    # Dimensionality reduction
+    if algo == "t-SNE":
         emb, kl = run_tsne(pts, perp, seed)
     elif algo == "PCA":
         emb, kl = run_pca(pts)
     else:
         emb, kl = run_umap(pts, neighbors, min_dist, seed)
 
+    # Trustworthiness
     n_samples = pts.shape[0]
     k_max = (n_samples - 1) // 2
-    if k_max >= 1:
-        tw = trustworthiness(pts, emb, n_neighbors=k_max)
-    else:
-        tw = None
+    tw = trustworthiness(pts, emb, n_neighbors=k_max) if k_max >= 1 else None
 
-    df = pd.DataFrame(emb, columns=["x","y"])
+    # DataFrame for embedding
+    df = pd.DataFrame(emb, columns=["x", "y"])
+
+    # Clustering
     if do_cluster:
         if cluster_algo == "KMeans":
             labels = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(emb)
         else:
             labels = DBSCAN(eps=eps).fit_predict(emb)
         df["cluster"] = labels.astype(str)
-        fig = px.scatter(df, x="x", y="y", color="cluster",
-                         title=f"{algo} embedding with {cluster_algo}", width=700, height=500)
-    else:
-        fig = px.scatter(df, x="x", y="y",
-                         title=f"{algo} embedding", width=700, height=500)
-    st.subheader("2-D Embedding")
-    st.plotly_chart(fig, use_container_width=True)
 
+    # Pair-wise distances in embedding
+    dist_matrix = cdist(emb, emb, metric="euclidean")
+    dist_df = pd.DataFrame(dist_matrix,
+                           columns=[f"dist_{i}" for i in range(n_samples)])
+    out_df = pd.concat([df, dist_df], axis=1)
 
+    return out_df, {"kl": kl, "tw": tw, "k_max": k_max}
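The `k_max = (n_samples - 1) // 2` cap mirrors the requirement in scikit-learn's `trustworthiness` that `n_neighbors` stay below `n_samples / 2`; the guard then skips the metric entirely for two-point inputs. Worked values:

# Worked check of the k_max rule: the largest k with k < n_samples / 2.
for n_samples in (2, 4, 5, 500):
    k_max = (n_samples - 1) // 2
    print(n_samples, k_max)  # 2 -> 0 (metric skipped), 4 -> 1, 5 -> 2, 500 -> 249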
 
+if run:
+    results: List[Tuple[str, bytes]] = []
+
+    for name, pts in datasets:
+        st.subheader(f"Dataset: {name}")
+        out_df, stats = process_dataset(name, pts)
+        if out_df is None:
+            continue
+
+        # Scatter plot
+        color_arg = "cluster" if ("cluster" in out_df.columns) else None
+        fig = px.scatter(out_df, x="x", y="y", color=color_arg,
+                         title=f"{algo} embedding ({name})",
+                         width=700, height=500)
+        fig.update_traces(marker=dict(size=8))
+        fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
+        st.plotly_chart(fig, use_container_width=True)
+
+        # Stats
+        if stats["tw"] is not None:
+            st.markdown(f"**Trustworthiness (k={stats['k_max']}):** {stats['tw']:.3f}")
+        else:
+            st.markdown("**Trustworthiness:** Not enough samples to compute (need ≥ 3 points).")
+        if stats["kl"] is not None:
+            st.markdown(f"**t-SNE KL divergence:** {stats['kl']:.3f}")
+
+        # Distance matrix preview
+        with st.expander("Show pair-wise distance matrix"):
+            st.dataframe(out_df.filter(like="dist_"))
+
+        # Download CSV for this dataset
+        csv_bytes = out_df.to_csv(index=False).encode("utf-8")
+        st.download_button(
+            f"Download embedding + distances ({name})",
+            data=csv_bytes,
+            file_name=f"{name}_embedding_with_distances.csv",
+            mime="text/csv"
+        )
+
+        # Keep for ZIP if batch
+        results.append((name, csv_bytes))
+
+    # One-click ZIP if multiple datasets
+    if len(results) >= 2:
+        zip_buf = io.BytesIO()
+        with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
+            for nm, csv_b in results:
+                zf.writestr(f"{nm}_embedding_with_distances.csv", csv_b)
+        st.download_button(
+            "📦 Download **all** results as ZIP",
+            data=zip_buf.getvalue(),
+            file_name="all_embeddings_with_distances.zip",
+            mime="application/zip"
+        )
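For a sense of the CSV each download button serves: with clustering enabled, the table holds the two embedding columns, the stringified cluster label, and one `dist_i` column per point. A minimal reconstruction with made-up coordinates:

# Minimal reconstruction of the per-dataset output table; coordinates are made up.
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

emb = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])  # fake 3-point embedding
df = pd.DataFrame(emb, columns=["x", "y"])
df["cluster"] = ["0", "0", "1"]                       # fake labels

dist_df = pd.DataFrame(cdist(emb, emb, metric="euclidean"),
                       columns=[f"dist_{i}" for i in range(len(emb))])
out_df = pd.concat([df, dist_df], axis=1)

print(out_df.columns.tolist())
# ['x', 'y', 'cluster', 'dist_0', 'dist_1', 'dist_2']
print(out_df.loc[0, "dist_1"])  # 5.0 -- distance from point 0 to point 1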
|