|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import marimo |
|
|
|
# Version string stored alongside the notebook (presumably the marimo release
# that generated this file — confirm against marimo's file-format docs).
__generated_with = "0.9.33"
# Application object that every cell below registers on via `@app.cell`.
app = marimo.App(width="medium")
|
|
|
|
|
# Imports marimo under the alias `mo` and returns it, so that the cells
# declaring a `mo` parameter can use the library.
@app.cell
def __():
    import marimo as mo

    return (mo,)
|
|
|
|
|
# Intro cell: passes the notebook title, a quoted summary and a link to the
# accompanying blog post to `mo.md`. Defines no names for other cells.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        # Visualizing text embeddings using MotherDuck and marimo

        > Text embeddings have become a crucial tool in AI/ML applications, allowing us to convert text into numerical vectors that capture semantic meaning. These vectors are often used for semantic search, but in this blog post, we'll explore how to visualize and explore text embeddings interactively using MotherDuck and marimo.

        [_Read the full blog here._](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/)
        """
    )
    return
|
|
|
|
|
# Section header cell: explains where the sample data comes from and shows the
# SQL a reader would run to attach the database and read the table. The two
# statements in the snippet are separated by `;` so the snippet can be copied
# and executed as-is.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        """
        ## Connecting to MotherDuck and Loading Sample Data

        This data has already been pre-computed, but you can fork and edit this notebook to run with your own data!

        ```sql
        ATTACH IF NOT EXISTS 'md:my_db';
        SELECT * FROM my_db.demo_with_embeddings;
        ```
        """
    )
    return
|
|
|
|
|
# SQL cell: runs `ATTACH IF NOT EXISTS 'md:my_db'` through `mo.sql`; the result
# is kept in the cell-private name `_df` and not used further.
# NOTE(review): `my_db` is never bound as a Python variable in this body, yet it
# appears in the return tuple. Presumably this is marimo's generated record that
# the cell defines the attached database for dependent cells — confirm before
# calling this function directly, where it would raise NameError.
@app.cell
def __(mo):
    _df = mo.sql(
        """
        ATTACH IF NOT EXISTS 'md:my_db'
        """
    )
    return (my_db,)
|
|
|
|
|
# SQL cell kept for reference only: every line of the statement is a SQL
# comment, so `mo.sql` receives no executable statement. The commented-out query
# shows how `my_db.demo_embedding_data` was built from the Hacker News stories
# parquet file (distinct URLs, titles containing 'database', score > 5).
@app.cell
def __(mo):
    _df = mo.sql(
        """
        -- Commented out as we have already run the embeddings for showcasing purposes.

        -- CREATE OR REPLACE TABLE my_db.demo_embedding_data AS
        -- SELECT DISTINCT ON (url) * -- Remove duplicate URLs
        -- FROM 'hf://datasets/julien040/hacker-news-posts/story.parquet'
        -- WHERE contains(title, 'database') -- Filter for posts about databases
        -- AND score > 5 -- Only include popular posts
        -- LIMIT 50000;
        """
    )
    return
|
|
|
|
|
# SQL cell: loads the pre-computed embeddings table into `embeddings`, with
# `title` and `text_embedding` moved to the front and `id` / `comments` dropped.
# The commented-out statement documents how the table was originally created.
# The query is a plain (non-f) string: it has no placeholders, and this matches
# the other SQL cells and keeps any literal braces in the SQL from being
# interpreted as format fields.
# NOTE(review): `demo_with_embeddings` and `my_db` are only referenced inside
# the SQL text, not in Python — presumably marimo-tracked SQL dependencies.
@app.cell
def __(demo_with_embeddings, mo, my_db):
    embeddings = mo.sql(
        """
        -- Commented out as we have already run the embeddings for showcasing purposes.
        -- CREATE TABLE my_db.demo_with_embeddings AS
        -- SELECT *, embedding(title) as text_embedding
        -- FROM my_db.demo_embedding_data
        -- LIMIT 1500;

        SELECT title, text_embedding, * EXCLUDE(id, title, text_embedding, comments) FROM my_db.demo_with_embeddings;
        """
    )
    return (embeddings,)
|
|
|
|
|
@app.cell
def __(PCA, hdbscan, np, umap):
    def umap_reduce(np_array, metric="cosine"):
        """
        Reduce the dimensionality of the embeddings to 2D using
        UMAP algorithm. UMAP preserves both local and global structure
        of the high-dimensional data.

        Args:
            np_array: Array of embeddings handed to ``UMAP.fit_transform``.
            metric: Distance metric name forwarded to ``umap.UMAP``.

        Returns:
            The 2-component projection returned by ``fit_transform``.
        """
        reducer = umap.UMAP(
            n_components=2,
            metric=metric,
            n_neighbors=80,
            min_dist=0.1,
        )
        return reducer.fit_transform(np_array)

    def cluster_points(np_array, min_cluster_size=4, max_cluster_size=50):
        """
        Cluster the embeddings using HDBSCAN algorithm.

        The input is first reduced with PCA to at most 50 components to
        speed up clustering, while still preserving most of the important
        information.

        Args:
            np_array: Array of embeddings, one row per point.
            min_cluster_size: Forwarded to ``hdbscan.HDBSCAN``.
            max_cluster_size: Forwarded to ``hdbscan.HDBSCAN``.

        Returns:
            A string array with one label per point: ``"outlier"`` where
            HDBSCAN assigned ``-1``, otherwise ``"cluster_<label>"``.
        """
        # PCA cannot extract more components than min(n_samples, n_features),
        # so clamp the target for small or low-dimensional inputs instead of
        # letting a fixed 50 raise.
        n_components = min((50, *np.shape(np_array)))
        reduced = PCA(n_components=n_components).fit_transform(np_array)

        hdb = hdbscan.HDBSCAN(
            min_samples=3,
            min_cluster_size=min_cluster_size,
            max_cluster_size=max_cluster_size,
        ).fit(reduced)

        # np.char.add concatenates element-wise on every NumPy version; the
        # `+` operator between string arrays only works on NumPy >= 2.0.
        cluster_names = np.char.add("cluster_", hdb.labels_.astype(str))
        return np.where(hdb.labels_ == -1, "outlier", cluster_names)

    return cluster_points, umap_reduce
|
|
|
|
|
@app.cell
def __(mo):
    # Range slider whose (low, high) pair is used as HDBSCAN's
    # (min_cluster_size, max_cluster_size). The range starts at 2 because
    # hdbscan rejects a minimum cluster size of 1, which would otherwise make
    # the clustering cell raise as soon as the lower handle is dragged fully
    # left. The default selection (4, 50) is unchanged.
    cluster_size_slider = mo.ui.range_slider(
        start=2,
        stop=80,
        value=(4, 50),
        step=1,
        show_value=True,
        debounce=True,
        label="Cluster Size (min, max)",
    )
    # Distance metric handed to the UMAP reduction step.
    metric_dropdown = mo.ui.dropdown(
        ["cosine", "euclidean", "manhattan", "mahalanobis"],
        value="cosine",
        label="Distance Metric",
    )
    return cluster_size_slider, metric_dropdown
|
|
|
|
|
# Section header cell for the processing step. Marked `hide_code=True` like the
# notebook's other prose-only cells so only the rendered Markdown is shown.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        ## Processing the Data

        Now we'll transform our high-dimensional embeddings into something we can visualize, using `umap_reduce` and `cluster_points`. More details on this step [in the blog](https://motherduck.com/blog/MotherDuck-Visualize-Embeddings-Marimo/).
        """
    )
    return
|
|
|
|
|
# Runs the two processing steps under a status spinner:
#   1. converts the `text_embedding` column of `embeddings` to a NumPy array,
#   2. clusters it with `cluster_points`, using the slider's (low, high) pair as
#      (min_cluster_size, max_cluster_size),
#   3. projects the same array to 2D with `umap_reduce`, using the metric
#      selected in the dropdown.
# The spinner text is switched between steps 2 and 3. `mo.show_code()` is the
# cell's final expression, so the body is left untouched — any change to it
# would also change what the cell displays.
# Returns the 2D projection, the raw embedding array and the cluster labels.
@app.cell
def __(
    cluster_points,
    cluster_size_slider,
    embeddings,
    metric_dropdown,
    mo,
    umap_reduce,
):
    with mo.status.spinner("Clustering points...") as _s:
        embeddings_array = embeddings["text_embedding"].to_numpy()
        hdb_labels = cluster_points(
            embeddings_array,
            min_cluster_size=cluster_size_slider.value[0],
            max_cluster_size=cluster_size_slider.value[1],
        )
        _s.update("Reducing dimensionality...")
        embeddings_2d = umap_reduce(embeddings_array, metric=metric_dropdown.value)
    mo.show_code()
    return embeddings_2d, embeddings_array, hdb_labels
|
|
|
|
|
@app.cell
def __(cluster_size_slider, metric_dropdown, mo):
    # Show the cluster-size slider and the metric dropdown in one horizontal
    # stack; the stack is the cell's final expression.
    _controls = [cluster_size_slider, metric_dropdown]
    mo.hstack(_controls)
    return
|
|
|
|
|
@app.cell
def __(embeddings, embeddings_2d, hdb_labels, pl):
    # Build the plotting table in a single lazy pipeline:
    #   - attach the two projected coordinates and the cluster label,
    #   - keep the first row per URL (original order preserved),
    #   - drop the raw embedding column,
    #   - discard points labelled "outlier",
    # then materialise the result.
    data = (
        embeddings.lazy()
        .with_columns(
            text_embedding_2d_1=embeddings_2d[:, 0],
            text_embedding_2d_2=embeddings_2d[:, 1],
            cluster=hdb_labels,
        )
        .unique(subset=["url"], maintain_order=True)
        .drop(["text_embedding"])
        .filter(pl.col("cluster") != "outlier")
        .collect()
    )
    return (data,)
|
|
|
|
|
@app.cell
def __(data):
    # Preview of the plotting table restricted to the columns of interest; the
    # selection is the cell's final expression.
    _preview_columns = (
        "title",
        "cluster",
        "text_embedding_2d_1",
        "text_embedding_2d_2",
        "score",
    )
    data.select(*_preview_columns)
    return
|
|
|
|
|
# Builds an Altair point chart from `data`: the two projected coordinates on
# x / y (scales not forced to include zero), colour by `cluster`, and a tooltip
# with title, score and cluster. The chart is then wrapped in
# `mo.ui.altair_chart` and rebound to `chart`, which other cells read.
# `mo.show_code()` is the cell's final expression, so the body is left
# untouched — any change to it would also change what the cell displays.
@app.cell
def __(alt, data, mo):
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x=alt.X("text_embedding_2d_1").scale(zero=False),
            y=alt.Y("text_embedding_2d_2").scale(zero=False),
            color="cluster",
            tooltip=["title", "score", "cluster"],
        )
    )
    chart = mo.ui.altair_chart(chart)
    mo.show_code()
    return (chart,)
|
|
|
|
|
# Section header cell introducing the interactive chart that the following
# cells display. Defines no names for other cells.
@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        ## Creating an Interactive Visualization

        We will plot the 2D representation of the text embeddings, colored by the clusters identified by HDBSCAN. You can select points on the chart to explore the text embeddings further. 👇
        """
    )
    return
|
|
|
|
|
# Displays the chart: `chart` is the cell's only expression.
@app.cell
def __(chart):
    chart
    return
|
|
|
|
|
# Displays `chart.value`.
# NOTE(review): presumably the rows currently selected on the chart, as the
# preceding prose cell suggests — confirm against the marimo altair_chart docs.
@app.cell
def __(chart):
    chart.value
    return
|
|
|
|
|
@app.cell
def __(mo):
    # Empty 400px-tall <div> emitted as the cell output; it only adds vertical
    # space at this point of the notebook.
    _spacer = mo.Html("<div style='height: 400px;'></div>")
    _spacer
    return
|
|
|
|
|
@app.cell
def __():
    # Third-party dependencies of the notebook, imported in one place and
    # returned so other cells can take them as parameters.

    # Data handling.
    # NOTE(review): duckdb, numba and pyarrow are not referenced by name in any
    # cell visible here — presumably imported so they are available to
    # mo.sql / polars / umap; confirm before removing.
    import duckdb
    import numba
    import numpy as np
    import polars as pl
    import pyarrow

    # Plotting.
    import altair as alt

    # Dimensionality reduction and clustering.
    import hdbscan
    import umap
    from sklearn.decomposition import PCA

    return PCA, alt, duckdb, hdbscan, np, numba, pl, pyarrow, umap
|
|
|
|
|
# Run the notebook application when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|
|