euler314 committed on
Commit
d7768bb
·
verified ·
1 Parent(s): 599c56e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -90
app.py CHANGED
@@ -1,21 +1,21 @@
1
  import io
2
- import textwrap
3
  import itertools
 
4
  import zipfile
5
  from typing import List, Tuple
6
 
7
  import numpy as np
8
  import pandas as pd
9
- import streamlit as st
10
- from sklearn.manifold import TSNE, trustworthiness
11
- from sklearn.decomposition import PCA
12
- from sklearn.cluster import KMeans, DBSCAN
13
- import umap.umap_ as umap
14
  import plotly.express as px
 
15
  from scipy.spatial.distance import cdist
 
 
16
  from sklearn.datasets import make_swiss_roll
 
 
17
 
18
- # ── Example shapes (some generated on demand) ────────────────────────────────
19
  def generate_hypercube(n=4):
20
  return np.array(list(itertools.product([0, 1], repeat=n)), dtype=float)
21
 
@@ -29,16 +29,16 @@ def generate_swiss_roll(n_samples=500, noise=0.05):
29
  return X
30
 
31
  EXAMPLE_SHAPES = {
32
- "Cube (3‑D, 8 pts)": np.array([
33
  [0,0,0],[0,0,1],[0,1,0],[0,1,1],
34
  [1,0,0],[1,0,1],[1,1,0],[1,1,1]
35
  ], dtype=float),
36
- "Square pyramid (3‑D, 5 pts)": np.array([
37
  [-1,-1,0],[1,-1,0],[1,1,0],[-1,1,0],[0,0,1]
38
  ], dtype=float),
39
- "4‑D hypercube (16 pts)": generate_hypercube(4),
40
- "3‑simplex (4 pts in 3‑D)": generate_simplex(3),
41
- "Swiss roll (500 pts, 3‑D)": generate_swiss_roll,
42
  }
43
 
44
  # ── Helpers ──────────────────────────────────────────────────────────────────
@@ -62,17 +62,24 @@ def run_umap(data, n_neighbors, min_dist, seed):
62
  min_dist=min_dist, random_state=seed)
63
  return um.fit_transform(data), None
64
 
 
 
 
 
 
 
65
  # ── Streamlit UI ─────────────────────────────────────────────────────────────
66
  st.set_page_config(layout="wide")
67
  st.title("πŸŒ€ Dimensionality Reduction Explorer")
68
 
69
  st.write("""
70
- Upload **one or many** CSV/TXT files *or* use the other sources, pick an algorithm,
71
- (optionally cluster), and explore the 2‑D embedding. Each result is downloadable
72
- with a full pair‑wise distance table.
 
73
  """)
74
 
75
- # Sidebar ────────────────────────────────────────────────────────────────────
76
  with st.sidebar:
77
  st.header("1️⃣ Data Input")
78
  mode = st.radio("Source", ["Example shape", "Upload CSV/TXT", "Paste text"])
@@ -87,9 +94,7 @@ with st.sidebar:
87
 
88
  elif mode == "Upload CSV/TXT":
89
  uploads = st.file_uploader(
90
- "Upload one **or many** files",
91
- type=["csv", "txt"],
92
- accept_multiple_files=True
93
  )
94
  if not uploads:
95
  st.stop()
@@ -102,14 +107,13 @@ with st.sidebar:
102
  txt = st.text_area("Paste coordinates", height=200, placeholder=placeholder)
103
  if not txt.strip():
104
  st.stop()
105
- data_raw = parse_text_points(txt)
106
- datasets.append(("pasted_points", data_raw))
107
 
108
  st.header("2️⃣ Algorithm & Params")
109
- algo = st.selectbox("Method", ["t‑SNE", "PCA", "UMAP"])
110
  seed = st.number_input("Random seed", value=42, step=1)
111
 
112
- if algo == "t‑SNE":
113
  perp = st.slider("Perplexity", 1.0, 50.0, 30.0, 1.0)
114
  elif algo == "UMAP":
115
  neighbors = st.slider("n_neighbors", 5, 200, 15, 5)
@@ -128,28 +132,30 @@ with st.sidebar:
128
  run = st.button("Run & Visualize πŸš€")
129
 
130
  # ── Main processing ─────────────────────────────────────────────────────────
131
- def process_dataset(name: str, pts: np.ndarray):
132
  if pts.ndim != 2 or pts.shape[0] < 2:
133
- st.error(f"Dataset **{name}** needs at least two points in an (n_pts Γ— n_dims) array.")
134
  return None, None
135
 
136
- # Dimensionality reduction
137
- if algo == "t‑SNE":
138
  emb, kl = run_tsne(pts, perp, seed)
139
  elif algo == "PCA":
140
  emb, kl = run_pca(pts)
141
  else:
142
  emb, kl = run_umap(pts, neighbors, min_dist, seed)
143
 
144
- # Trustworthiness
145
  n_samples = pts.shape[0]
146
  k_max = (n_samples - 1) // 2
147
  tw = trustworthiness(pts, emb, n_neighbors=k_max) if k_max >= 1 else None
148
 
149
- # DataFrame for embedding
150
- df = pd.DataFrame(emb, columns=["x", "y"])
 
 
151
 
152
- # Clustering
153
  if do_cluster:
154
  if cluster_algo == "KMeans":
155
  labels = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(emb)
@@ -157,65 +163,62 @@ def process_dataset(name: str, pts: np.ndarray):
157
  labels = DBSCAN(eps=eps).fit_predict(emb)
158
  df["cluster"] = labels.astype(str)
159
 
160
- # Pair‑wise distances in embedding
161
- dist_matrix = cdist(emb, emb, metric="euclidean")
162
- dist_df = pd.DataFrame(dist_matrix,
163
- columns=[f"dist_{i}" for i in range(n_samples)])
164
- out_df = pd.concat([df, dist_df], axis=1)
165
 
166
- return out_df, {"kl": kl, "tw": tw, "k_max": k_max}
 
167
 
168
- if run:
169
- results: List[Tuple[str, pd.DataFrame]] = []
170
-
171
- for name, pts in datasets:
172
- st.subheader(f"πŸ“‚ Dataset: {name}")
173
- out_df, stats = process_dataset(name, pts)
174
- if out_df is None:
175
- continue
176
-
177
- # Scatter plot
178
- color_arg = "cluster" if ("cluster" in out_df.columns) else None
179
- fig = px.scatter(out_df, x="x", y="y", color=color_arg,
180
- title=f"{algo} embedding ({name})",
181
- width=700, height=500)
182
- fig.update_traces(marker=dict(size=8))
183
- fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
184
- st.plotly_chart(fig, use_container_width=True)
185
-
186
- # Stats
187
- if stats["tw"] is not None:
188
- st.markdown(f"**Trustworthiness (k={stats['k_max']}):** {stats['tw']:.3f}")
189
- else:
190
- st.markdown("**Trustworthiness:** Not enough samples to compute (need β‰₯β€―3 points).")
191
- if stats["kl"] is not None:
192
- st.markdown(f"**t‑SNE KL divergence:** {stats['kl']:.3f}")
193
-
194
- # Distance matrix preview
195
- with st.expander("πŸ” Show pair‑wise distance matrix"):
196
- st.dataframe(out_df.filter(like="dist_"))
197
-
198
- # Download CSV for this dataset
199
- csv_bytes = out_df.to_csv(index=False).encode("utf‑8")
200
- st.download_button(
201
- f"Download embeddingβ€―+β€―distances ({name})",
202
- data=csv_bytes,
203
- file_name=f"{name}_embedding_with_distances.csv",
204
- mime="text/csv"
205
- )
206
 
207
- # Keep for ZIP if batch
208
- results.append((name, csv_bytes))
209
-
210
- # One‑click ZIP if multiple datasets
211
- if len(results) >= 2:
212
- zip_buf = io.BytesIO()
213
- with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
214
- for nm, csv_b in results:
215
- zf.writestr(f"{nm}_embedding_with_distances.csv", csv_b)
216
- st.download_button(
217
- "πŸ“¦ Download **all** results as ZIP",
218
- data=zip_buf.getvalue(),
219
- file_name="all_embeddings_with_distances.zip",
220
- mime="application/zip"
221
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
 
2
  import itertools
3
+ import textwrap
4
  import zipfile
5
  from typing import List, Tuple
6
 
7
  import numpy as np
8
  import pandas as pd
 
 
 
 
 
9
  import plotly.express as px
10
+ import streamlit as st
11
  from scipy.spatial.distance import cdist
12
+ from sklearn.cluster import DBSCAN, KMeans
13
+ from sklearn.decomposition import PCA
14
  from sklearn.datasets import make_swiss_roll
15
+ from sklearn.manifold import TSNE, trustworthiness
16
+ import umap.umap_ as umap
17
 
18
+ # ── Example shapes ───────────────────────────────────────────────────────────
19
  def generate_hypercube(n=4):
20
  return np.array(list(itertools.product([0, 1], repeat=n)), dtype=float)
21
 
 
29
  return X
30
 
31
  EXAMPLE_SHAPES = {
32
+ "Cube (3-D, 8 pts)": np.array([
33
  [0,0,0],[0,0,1],[0,1,0],[0,1,1],
34
  [1,0,0],[1,0,1],[1,1,0],[1,1,1]
35
  ], dtype=float),
36
+ "Square pyramid (3-D, 5 pts)": np.array([
37
  [-1,-1,0],[1,-1,0],[1,1,0],[-1,1,0],[0,0,1]
38
  ], dtype=float),
39
+ "4-D hypercube (16 pts)": generate_hypercube(4),
40
+ "3-simplex (4 pts in 3-D)": generate_simplex(3),
41
+ "Swiss roll (500 pts, 3-D)": generate_swiss_roll,
42
  }
43
 
44
  # ── Helpers ──────────────────────────────────────────────────────────────────
 
62
  min_dist=min_dist, random_state=seed)
63
  return um.fit_transform(data), None
64
 
65
+ def distinct_count(dist_row: np.ndarray, tol: float = 1e-3) -> int:
66
+ """Count unique non-zero distances in a row after rounding to 3 decimals."""
67
+ nz = dist_row[dist_row > tol]
68
+ rounded = (nz * 1000).round().astype(int) # rounding to 3 d.p.
69
+ return len(np.unique(rounded))
70
+
71
  # ── Streamlit UI ─────────────────────────────────────────────────────────────
72
  st.set_page_config(layout="wide")
73
  st.title("πŸŒ€ Dimensionality Reduction Explorer")
74
 
75
  st.write("""
76
+ Upload **one or many** CSV/TXT files *or* use an example shape, pick an algorithm,
77
+ (optionally cluster), and explore the 2-D embedding.
78
+ Every output CSV now contains the embedding, the original point coordinates,
79
+ all pair-wise distances, **and** the number of distinct distances per point.
80
  """)
81
 
82
+ # ── Sidebar ──────────────────────────────────────────────────────────────────
83
  with st.sidebar:
84
  st.header("1️⃣ Data Input")
85
  mode = st.radio("Source", ["Example shape", "Upload CSV/TXT", "Paste text"])
 
94
 
95
  elif mode == "Upload CSV/TXT":
96
  uploads = st.file_uploader(
97
+ "Upload file(s)", type=["csv", "txt"], accept_multiple_files=True
 
 
98
  )
99
  if not uploads:
100
  st.stop()
 
107
  txt = st.text_area("Paste coordinates", height=200, placeholder=placeholder)
108
  if not txt.strip():
109
  st.stop()
110
+ datasets.append(("pasted_points", parse_text_points(txt)))
 
111
 
112
  st.header("2️⃣ Algorithm & Params")
113
+ algo = st.selectbox("Method", ["t-SNE", "PCA", "UMAP"])
114
  seed = st.number_input("Random seed", value=42, step=1)
115
 
116
+ if algo == "t-SNE":
117
  perp = st.slider("Perplexity", 1.0, 50.0, 30.0, 1.0)
118
  elif algo == "UMAP":
119
  neighbors = st.slider("n_neighbors", 5, 200, 15, 5)
 
132
  run = st.button("Run & Visualize πŸš€")
133
 
134
  # ── Main processing ─────────────────────────────────────────────────────────
135
+ def process_dataset(name: str, pts: np.ndarray) -> Tuple[pd.DataFrame, dict]:
136
  if pts.ndim != 2 or pts.shape[0] < 2:
137
+ st.error(f"Dataset **{name}** needs at least two points (rows).")
138
  return None, None
139
 
140
+ # 1. Reduce dimensionality
141
+ if algo == "t-SNE":
142
  emb, kl = run_tsne(pts, perp, seed)
143
  elif algo == "PCA":
144
  emb, kl = run_pca(pts)
145
  else:
146
  emb, kl = run_umap(pts, neighbors, min_dist, seed)
147
 
148
+ # 2. Trustworthiness
149
  n_samples = pts.shape[0]
150
  k_max = (n_samples - 1) // 2
151
  tw = trustworthiness(pts, emb, n_neighbors=k_max) if k_max >= 1 else None
152
 
153
+ # 3. Build DataFrame in requested column order
154
+ df_emb = pd.DataFrame(emb, columns=["x", "y"])
155
+ df_pts = pd.DataFrame(pts, columns=[f"p{i}" for i in range(pts.shape[1])])
156
+ df = pd.concat([df_emb, df_pts], axis=1)
157
 
158
+ # 4. Clustering (optional)
159
  if do_cluster:
160
  if cluster_algo == "KMeans":
161
  labels = KMeans(n_clusters=n_clusters, random_state=seed).fit_predict(emb)
 
163
  labels = DBSCAN(eps=eps).fit_predict(emb)
164
  df["cluster"] = labels.astype(str)
165
 
166
+ # 5. Pair-wise distances (in embedding space)
167
+ dists = cdist(emb, emb, metric="euclidean")
168
+ dist_df = pd.DataFrame(dists, columns=[f"dist_{i}" for i in range(n_samples)])
169
+ df = pd.concat([df, dist_df], axis=1)
 
170
 
171
+ # 6. Distinct-distance count per point
172
+ df["distinct_count"] = [distinct_count(row) for row in dists]
173
 
174
+ return df, {"kl": kl, "tw": tw, "k_max": k_max}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
+ if run:
177
+ zip_buffer = io.BytesIO()
178
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
179
+ for name, pts in datasets:
180
+ st.subheader(f"πŸ“‚ Dataset: {name}")
181
+ out_df, stats = process_dataset(name, pts)
182
+ if out_df is None:
183
+ continue
184
+
185
+ # Plot
186
+ color_col = "cluster" if "cluster" in out_df.columns else None
187
+ fig = px.scatter(out_df, x="x", y="y", color=color_col,
188
+ title=f"{algo} embedding ({name})",
189
+ width=700, height=500)
190
+ fig.update_traces(marker=dict(size=8))
191
+ fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
192
+ st.plotly_chart(fig, use_container_width=True)
193
+
194
+ # Stats
195
+ if stats["tw"] is not None:
196
+ st.markdown(f"**Trustworthiness (k={stats['k_max']}):** {stats['tw']:.3f}")
197
+ else:
198
+ st.markdown("**Trustworthiness:** Not enough samples to compute.")
199
+ if stats["kl"] is not None:
200
+ st.markdown(f"**t-SNE KL divergence:** {stats['kl']:.3f}")
201
+
202
+ # Data preview
203
+ with st.expander("Preview first 10 rows"):
204
+ st.dataframe(out_df.head(10))
205
+
206
+ # Individual CSV download
207
+ csv_bytes = out_df.to_csv(index=False).encode("utf-8")
208
+ st.download_button(
209
+ f"Download CSV ({name})",
210
+ data=csv_bytes,
211
+ file_name=f"{name}_embedding_with_distances.csv",
212
+ mime="text/csv"
213
+ )
214
+
215
+ # Add to ZIP
216
+ zf.writestr(f"{name}_embedding_with_distances.csv", csv_bytes)
217
+
218
+ # ZIP download (always available once run)
219
+ st.download_button(
220
+ "πŸ“¦ Download ALL results as ZIP",
221
+ data=zip_buffer.getvalue(),
222
+ file_name="all_embeddings_with_distances.zip",
223
+ mime="application/zip"
224
+ )