import os
import pickle
import logging

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
import vec2text
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from umap import UMAP

tqdm.pandas()
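# Assumption: the original app targets CUDA throughout; falling back to CPU
# keeps the script runnable (if slow) on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"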
def file_cache(file_path):
    """Cache a function's return value on disk as a pickle at `file_path`."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            # Reuse the on-disk result if it exists.
            if os.path.exists(file_path):
                print(f"Loading cached data from {file_path}")
                with open(file_path, "rb") as f:
                    return pickle.load(f)
            # Otherwise compute the result and cache it for next time.
            os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
            result = func(*args, **kwargs)
            with open(file_path, "wb") as f:
                pickle.dump(result, f)
            print(f"Saving new cache to {file_path}")
            return result
        return wrapper
    return decorator
@st.cache_resource
def vector_compressor_from_config():
    # 2-D UMAP reducer; kept behind a factory so the projection method is easy to swap.
    return UMAP(n_components=2)


@st.cache_data
def load_data():
    # Reddit SYAC titles hosted on the Hugging Face Hub.
    return pd.read_csv(
        "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    )


df = load_data()
@st.cache_resource
def load_model_and_tokenizer():
    # GTR-T5 encoder whose embedding space the vec2text corrector inverts.
    encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
@st.cache_resource
def load_corrector():
    # Pre-trained vec2text corrector for the GTR-base embedding space.
    return vec2text.load_pretrained_corrector("gtr-base")


corrector = load_corrector()
@st.cache_data
def load_embeddings():
    # Precomputed GTR embeddings of the dataset titles (see the sketch below
    # for one way such a file could be produced).
    return np.load("syac-title-embeddings.npy")


embeddings = load_embeddings()
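# Not called at runtime: a minimal sketch of how "syac-title-embeddings.npy"
# could have been produced, assuming the mean-pooled GTR recipe from the
# vec2text README; the function name is illustrative only.
def get_gtr_embeddings(text_list, encoder, tokenizer):
    inputs = tokenizer(
        text_list,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    ).to(DEVICE)
    with torch.no_grad():
        model_output = encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        # Mean-pool the final hidden states, ignoring padding positions.
        embeddings = vec2text.models.model_utils.mean_pool(
            model_output.last_hidden_state, inputs["attention_mask"]
        )
    return embeddings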
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    # Fit UMAP once; the result is cached both on disk and by Streamlit.
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer


vectors_2d, reducer = reduce_embeddings(embeddings)
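# The fitted reducer is kept because UMAP can approximately invert its
# projection (reducer.inverse_transform), which is how a clicked 2-D point
# is mapped back into the original embedding space below.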
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={"x": "UMAP Dimension 1", "y": "UMAP Dimension 2"},
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],
)

fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)
# Default coordinates shown before any point has been clicked.
x, y = 0.0, 0.0

# Render the figure and capture click events as a list of point dicts
# (each with "x" and "y" keys, among others).
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)
st.sidebar.header("Scatter Plot Info")
st.sidebar.write("""
This scatter plot visualizes the UMAP dimensionality reduction of Reddit post titles.
Each point represents a post, with similar titles positioned closer together.
""")
st.sidebar.write("Use the form below to select coordinates or click on a point in the scatter plot.")
st.sidebar.markdown("---")

st.sidebar.header("How to Use")
st.sidebar.write("""
1. **Click a point** in the scatter plot to see the corresponding coordinates.
2. **Adjust the coordinates** using the form inputs if needed.
3. **Submit** to see the reconstructed text output.
""")
with st.form(key="form1"):
    # Pre-fill the inputs with the most recently clicked point, if any.
    if selected_points:
        clicked_point = selected_points[0]
        x = clicked_point["x"]
        y = clicked_point["y"]

    x = st.number_input("X Coordinate", value=x, format="%.10f")
    y = st.number_input("Y Coordinate", value=y, format="%.10f")

    submit_button = st.form_submit_button("Submit")

    if selected_points or submit_button:
        # Map the 2-D point back into embedding space, then invert the
        # embedding to text with vec2text.
        inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
        inferred_embedding = inferred_embedding.astype("float32")

        output = vec2text.invert_embeddings(
            embeddings=torch.tensor(inferred_embedding).to(DEVICE),
            corrector=corrector,
            num_steps=20,
        )

        st.text(str(output))
        st.text(str(inferred_embedding))
    else:
        st.text("Click on a point in the scatterplot to see its coordinates.")