|
import streamlit as st |
|
import pandas as pd |
|
import vec2text |
|
import torch |
|
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel |
|
from umap import UMAP |
|
from tqdm import tqdm |
|
import plotly.express as px |
|
import numpy as np |
|
|
|
|
|
# Enable tqdm's pandas integration for this session.
tqdm.pandas()
|
|
|
|
|
@st.cache_data
def load_data():
    """Read the reddit-syac-urls ``train.csv`` split into a DataFrame.

    The CSV is fetched directly from its Hugging Face dataset URL with
    ``pandas.read_csv``. The function is decorated with ``st.cache_data``.

    Returns:
        The parsed CSV as a ``pandas.DataFrame``.
    """
    csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(csv_url)
    return frame


df = load_data()
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5-base encoder and its matching tokenizer.

    The encoder sub-module of the ``sentence-transformers/gtr-t5-base``
    checkpoint is moved to the GPU when CUDA is available and left on the
    CPU otherwise. The function is decorated with ``st.cache_resource``.

    Returns:
        An ``(encoder, tokenizer)`` tuple.
    """
    model_name = "sentence-transformers/gtr-t5-base"

    # Do not assume a GPU: an unconditional ``.to("cuda")`` raises on hosts
    # without a CUDA device, so fall back to the CPU there.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encoder = AutoModel.from_pretrained(model_name).encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
|
|
|
|
|
@st.cache_resource
def load_corrector():
    """Load vec2text's pretrained corrector registered as ``"gtr-base"``.

    The function is decorated with ``st.cache_resource``.

    Returns:
        The object returned by ``vec2text.load_pretrained_corrector``.
    """
    pretrained_corrector = vec2text.load_pretrained_corrector("gtr-base")
    return pretrained_corrector


corrector = load_corrector()
|
|
|
|
|
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from ``syac-title-embeddings.npy``.

    The path is relative, so the file must be reachable from the process's
    working directory. The function is decorated with ``st.cache_data``.

    Returns:
        The array stored in the ``.npy`` file.
        NOTE(review): presumably one embedding row per row of the titles
        dataset -- confirm against how the file was produced.
    """
    embeddings_path = "syac-title-embeddings.npy"
    title_embeddings = np.load(embeddings_path)
    return title_embeddings


embeddings = load_embeddings()
|
|
|
|
|
@st.cache_resource
def reduce_embeddings(embeddings):
    """Fit a default-configured UMAP on ``embeddings`` and project them.

    The function is decorated with ``st.cache_resource``.

    Args:
        embeddings: The vectors to fit the reducer on and to project.

    Returns:
        A ``(projection, reducer)`` tuple: the output of ``fit_transform``
        and the fitted ``UMAP`` instance, which is kept so that points can
        later be mapped back with ``inverse_transform``.
    """
    umap_model = UMAP()
    projection = umap_model.fit_transform(embeddings)
    return projection, umap_model


vectors_2d, reducer = reduce_embeddings(embeddings)
|
|
|
|
|
# Scatter plot of the UMAP projection: first projected column on the x axis,
# second on the y axis, with each point's dataset title added to its hover box.
fig = px.scatter(
    title="UMAP Scatter Plot of Reddit Titles",
    labels={"x": "UMAP Dimension 1", "y": "UMAP Dimension 2"},
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    hover_data={"Title": df["title"]},
    # Semi-transparent markers so overlapping points remain distinguishable.
    opacity=0.4,
)

# Render the figure in the Streamlit page.
st.plotly_chart(fig)
|
|
|
|
|
# Form for choosing a point in the projected plane. Widgets inside the form
# are submitted together when the "Submit" button is pressed.
with st.form(key="form1"):
    x = st.number_input("Input X coordinate")
    y = st.number_input("Input Y coordinate")
    submit_button = st.form_submit_button("Submit")

if submit_button:
    # Map the chosen 2-D point back into the original embedding space using
    # the fitted reducer.
    inferred_embedding = reducer.inverse_transform([[x, y]])

    # Do not assume a GPU: an unconditional ``.cuda()`` raises on hosts
    # without a CUDA device, so fall back to the CPU there.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the tensor explicitly as float32. ``torch.tensor`` otherwise
    # inherits the NumPy dtype, which may be float64.
    # NOTE(review): presumably the corrector's weights are float32 -- confirm.
    embedding_tensor = torch.tensor(
        inferred_embedding, dtype=torch.float32
    ).to(target_device)

    # Decode the embedding back into text with 20 correction steps.
    output = vec2text.invert_embeddings(
        embeddings=embedding_tensor,
        corrector=corrector,
        num_steps=20,
    )

    # Show the decoded text first, then the raw inferred embedding.
    st.text(str(output))
    st.text(str(inferred_embedding))
|
|