File size: 5,059 Bytes
2b22bff
 
8aa44e7
 
 
 
644a030
8aa44e7
 
 
 
93f5069
644a030
 
 
2b22bff
8aa44e7
 
 
2b22bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f5069
 
2b22bff
2063af3
93f5069
2b22bff
8aa44e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b22bff
 
f616ac0
8aa44e7
93f5069
8aa44e7
 
 
 
 
 
 
 
644a030
8aa44e7
 
644a030
1988995
644a030
 
 
 
 
 
 
8aa44e7
0cb5b70
f616ac0
0cb5b70
644a030
 
0cb5b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1988995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import functools
import logging
import os
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import torch
import vec2text
from sklearn.decomposition import PCA
from streamlit_plotly_events import plotly_events
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from umap import UMAP

# Register tqdm's pandas integration (adds progress_apply / progress_map to
# pandas objects). NOTE(review): none of those helpers are called in the code
# visible here -- this call may be removable.
tqdm.pandas()

# Custom on-disk (pickle) cache decorator.
def file_cache(file_path):
    """Persist a function's return value to ``file_path`` with pickle.

    The first call runs the wrapped function and writes its result to
    ``file_path``. Every later call, in this or any future process, loads the
    pickled value instead of calling the function.

    NOTE: the cache key is the file path only; call arguments are ignored.
    Delete the file whenever the inputs change.
    NOTE: ``pickle.load`` can execute arbitrary code. Only point this at files
    the application itself wrote, never at untrusted data.

    Args:
        file_path: Destination of the pickle file. Missing parent directories
            are created on the first write.

    Returns:
        A decorator that wraps a function with this file cache.
    """
    def decorator(func):
        @functools.wraps(func)  # keep __name__/__doc__/__wrapped__ of func
        def wrapper(*args, **kwargs):
            if os.path.exists(file_path):
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            result = func(*args, **kwargs)

            # Create the cache directory (e.g. ".cache/") if needed;
            # otherwise open() raises FileNotFoundError.
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Write to a sibling temp file, then atomically move it into place
            # so an interrupted dump never leaves a truncated pickle behind
            # that would break every subsequent load.
            tmp_path = f"{file_path}.tmp"
            with open(tmp_path, "wb") as f:
                pickle.dump(result, f)
            os.replace(tmp_path, file_path)
            print(f"Saving new cache to {file_path}")
            return result
        return wrapper
    return decorator

@st.cache_resource
def vector_compressor_from_config():
    """Build the reducer that projects embeddings onto a 2-D plane.

    Returns:
        A ``UMAP`` instance configured with ``n_components=2``. It is cached
        as a resource, so every caller in the process receives the same object.
    """
    two_d_reducer = UMAP(n_components=2)
    return two_d_reducer

# The titles come from a remote CSV; st.cache_data keeps the download from
# repeating on every script rerun.
@st.cache_data
def load_data():
    """Download the reddit-syac-urls training split and return it as a DataFrame."""
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(csv_url)

df = load_data()

# Load the GTR-T5 checkpoint once per process instead of on every rerun.
@st.cache_resource
def load_model_and_tokenizer():
    """Return ``(encoder, tokenizer)`` for sentence-transformers/gtr-t5-base.

    Only the model's ``encoder`` submodule is returned, moved to CUDA.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    full_model = AutoModel.from_pretrained(checkpoint)
    gpu_encoder = full_model.encoder.to("cuda")
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return gpu_encoder, text_tokenizer

# NOTE(review): encoder/tokenizer are not referenced in the code visible here.
encoder, tokenizer = load_model_and_tokenizer()

# Keep a single vec2text corrector per process rather than reloading it on
# every rerun.
@st.cache_resource
def load_corrector():
    """Return the pretrained vec2text corrector for "gtr-base" embeddings."""
    pretrained = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained

corrector = load_corrector()

# The title embeddings are precomputed and stored in a large local file, so
# read it once and cache the result.
@st.cache_data
def load_embeddings():
    """Read the precomputed title embeddings from the local .npy file."""
    embeddings_path = "syac-title-embeddings.npy"
    return np.load(embeddings_path)

embeddings = load_embeddings()

# Fit UMAP once and persist (projection, fitted reducer) to disk via
# file_cache; st.cache_resource then keeps the loaded pair in memory.
@st.cache_resource
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Project ``embeddings`` to 2-D with UMAP.

    Returns:
        A ``(vectors_2d, reducer)`` tuple. ``vectors_2d`` is the
        ``fit_transform`` output. ``reducer`` is the fitted UMAP instance,
        kept so that 2-D points can later be mapped back with
        ``reducer.inverse_transform``.

    st.cache_resource (rather than st.cache_data) is used because the result
    holds a fitted model. cache_data would deserialize a fresh copy of the
    reducer on every script rerun; cache_resource returns the same object.

    NOTE(review): the on-disk cache is keyed only by its file path. It goes
    stale if the embeddings file changes; delete
    ``.cache/reducer_embeddings.pickle`` in that case.
    """
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer

vectors_2d, reducer = reduce_embeddings(embeddings)

# Scatter plot of the 2-D UMAP projection; hovering a point shows its title.
fig = px.scatter(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.6,
    hover_data={"Title": df["title"]},
    labels={"x": "UMAP Dimension 1", "y": "UMAP Dimension 2"},
    title="UMAP Scatter Plot of Reddit Titles",
    color_discrete_sequence=["#ff504c"],  # one coral-red colour for all points
)

# Transparent plot and paper backgrounds so the chart shows the page behind it.
# NOTE(review): template=None is intended to let the chart follow the
# light/dark theme; confirm this actually happens when rendered through
# plotly_events in both themes.
fig.update_layout(
    template=None,
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

# Coordinates shown in the form until a point is clicked.
x = y = 0.0

# Render the figure through the plotly_events component with click events on
# and hover events off; the return value holds the clicked point(s), each
# exposing 'x' and 'y'.
selected_points = plotly_events(
    fig,
    click_event=True,
    hover_event=False,
    override_height=600,
    override_width="100%",
)

# Explanatory text and usage instructions, rendered in the sidebar.
with st.sidebar:
    st.header("Scatter Plot Info")
    st.write(
        "\n"
        "This scatter plot visualizes the UMAP dimensionality reduction of Reddit post titles. \n"
        "Each point represents a post, with similar titles being positioned closer together.\n"
    )
    st.write("Use the form below to select coordinates or click on a point in the scatter plot.")
    st.markdown("---")

    st.header("How to Use")
    st.write(
        "\n"
        "1. **Click a point** in the scatter plot to see the corresponding coordinates.\n"
        "2. **Adjust the coordinates** using the form inputs if needed.\n"
        "3. **Submit** to see the reconstructed text output.\n"
    )

# Coordinate form. A clicked point pre-fills the inputs. Either a click or
# pressing Submit maps that 2-D location back to an embedding and inverts it
# to text.
with st.form(key="form1"):
    if selected_points:
        # Only the first reported point is used.
        clicked_point = selected_points[0]
        x = clicked_point["x"]
        y = clicked_point["y"]

    x = st.number_input("X Coordinate", value=x, format="%.10f")
    y = st.number_input("Y Coordinate", value=y, format="%.10f")

    submit_button = st.form_submit_button("Submit")

    if selected_points or submit_button:
        # Lift the 2-D point back into embedding space with the fitted UMAP
        # reducer, then cast to float32 before converting to a torch tensor.
        inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
        inferred_embedding = inferred_embedding.astype("float32")

        # Invert the embedding to text with the vec2text corrector
        # (num_steps=20). The tensor is moved to CUDA, so a GPU is required.
        output = vec2text.invert_embeddings(
            embeddings=torch.tensor(inferred_embedding).cuda(),
            corrector=corrector,
            num_steps=20,
        )

        st.text(str(output))
        st.text(str(inferred_embedding))
    else:
        st.text("Click on a point in the scatterplot to see its coordinates.")