# Source: Hugging Face Space by marksverdhei — commit 8aa44e7 ("Add app"),
# file size 2.39 kB (page header residue from the scrape, kept as a comment).
import streamlit as st
import pandas as pd
import vec2text
import torch
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel
from umap import UMAP
from tqdm import tqdm
import plotly.express as px
import numpy as np
# Register tqdm with pandas so DataFrame/Series gain .progress_apply
# (progress bars for long element-wise operations).
tqdm.pandas()
# Caching the dataframe since loading from external source can be time-consuming
# Cache the dataset: the CSV lives on the Hugging Face Hub, so re-fetching
# it on every Streamlit rerun would be slow.
@st.cache_data
def load_data():
    """Fetch the reddit-syac-urls train split from the Hub as a DataFrame."""
    url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(url)


df = load_data()
# Caching the model and tokenizer to avoid reloading
# Cache the encoder/tokenizer pair so the model weights are loaded only once
# per process instead of on every Streamlit rerun.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GTR-T5 encoder stack and its matching tokenizer.

    Returns:
        tuple: (encoder, tokenizer) — the T5 encoder moved to GPU when one
        is available (CPU otherwise), and the corresponding tokenizer.
    """
    # Fall back to CPU so the app can still start on machines without CUDA;
    # the previous hard-coded "cuda" crashed on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


encoder, tokenizer = load_model_and_tokenizer()
# Caching the vec2text corrector
# Cache the vec2text corrector — loading it pulls model weights, which is
# far too expensive to repeat on each rerun.
@st.cache_resource
def load_corrector():
    """Return the pretrained vec2text corrector for GTR-base embeddings."""
    corrector_name = "gtr-base"
    return vec2text.load_pretrained_corrector(corrector_name)


corrector = load_corrector()
# Caching the precomputed embeddings since they are stored locally and large
# Cache the precomputed title embeddings — the local .npy file is large.
@st.cache_data
def load_embeddings():
    """Load the precomputed GTR title embeddings from disk as an ndarray."""
    embeddings_path = "syac-title-embeddings.npy"
    return np.load(embeddings_path)


embeddings = load_embeddings()
# Caching UMAP reduction as it's a heavy computation
# Cache the UMAP fit — dimensionality reduction over the full embedding
# matrix is a heavy computation.
@st.cache_resource
def reduce_embeddings(embeddings):
    """Project the embeddings to 2-D with UMAP.

    Returns both the projected points and the fitted reducer so the
    inverse transform can later map user-picked 2-D coordinates back
    into the original embedding space.
    """
    umap_model = UMAP()
    projected = umap_model.fit_transform(embeddings)
    return projected, umap_model


vectors_2d, reducer = reduce_embeddings(embeddings)
# Add a scatter plot using Plotly
# Scatter plot of the 2-D UMAP projection; hovering a point shows its title.
scatter_kwargs = dict(
    x=vectors_2d[:, 0],
    y=vectors_2d[:, 1],
    opacity=0.4,
    hover_data={"Title": df["title"]},
    labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
    title="UMAP Scatter Plot of Reddit Titles",
)
fig = px.scatter(**scatter_kwargs)

# Render the figure in the Streamlit page.
st.plotly_chart(fig)
# Streamlit form to take user inputs and handle interaction
# Form: the user picks a point in UMAP space; we invert it back to an
# embedding and let vec2text reconstruct a plausible title from it.
with st.form(key="form1"):
    x = st.number_input("Input X coordinate")
    y = st.number_input("Input Y coordinate")
    submit_button = st.form_submit_button("Submit")

    if submit_button:
        # Map the chosen 2-D point back into the original embedding space.
        inferred_embedding = reducer.inverse_transform([[x, y]])
        # UMAP's inverse_transform yields float64; cast to float32 — the dtype
        # torch models are built for — and only move to GPU when one exists
        # (the previous unconditional .cuda() crashed on CPU-only hosts).
        # NOTE(review): assumes the corrector lives on the same device —
        # confirm against vec2text.load_pretrained_corrector's placement.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_tensor = torch.tensor(inferred_embedding, dtype=torch.float32).to(device)
        output = vec2text.invert_embeddings(
            embeddings=embedding_tensor,
            corrector=corrector,
            num_steps=20,
        )
        # Show both the reconstructed text and the raw inverted embedding.
        st.text(str(output))
        st.text(str(inferred_embedding))