marksverdhei commited on
Commit
8aa44e7
·
1 Parent(s): d8081ae
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import vec2text
4
+ import torch
5
+ from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel
6
+ from umap import UMAP
7
+ from tqdm import tqdm
8
+ import plotly.express as px
9
+ import numpy as np
10
+
11
+ # Activate tqdm with pandas
12
+ tqdm.pandas()
13
+
14
+ # Caching the dataframe since loading from external source can be time-consuming
15
+ @st.cache_data
16
+ def load_data():
17
+ return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
18
+
19
+ df = load_data()
20
+
21
+ # Caching the model and tokenizer to avoid reloading
22
+ @st.cache_resource
23
+ def load_model_and_tokenizer():
24
+ encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
25
+ tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
26
+ return encoder, tokenizer
27
+
28
+ encoder, tokenizer = load_model_and_tokenizer()
29
+
30
+ # Caching the vec2text corrector
31
+ @st.cache_resource
32
+ def load_corrector():
33
+ return vec2text.load_pretrained_corrector("gtr-base")
34
+
35
+ corrector = load_corrector()
36
+
37
+ # Caching the precomputed embeddings since they are stored locally and large
38
+ @st.cache_data
39
+ def load_embeddings():
40
+ return np.load("syac-title-embeddings.npy")
41
+
42
+ embeddings = load_embeddings()
43
+
44
+ # Caching UMAP reduction as it's a heavy computation
45
+ @st.cache_resource
46
+ def reduce_embeddings(embeddings):
47
+ reducer = UMAP()
48
+ return reducer.fit_transform(embeddings), reducer
49
+
50
+ vectors_2d, reducer = reduce_embeddings(embeddings)
51
+
52
+ # Add a scatter plot using Plotly
53
+ fig = px.scatter(
54
+ x=vectors_2d[:, 0],
55
+ y=vectors_2d[:, 1],
56
+ opacity=0.4,
57
+ hover_data={"Title": df["title"]},
58
+ labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
59
+ title="UMAP Scatter Plot of Reddit Titles"
60
+ )
61
+
62
+ # Display plot in Streamlit
63
+ st.plotly_chart(fig)
64
+
65
+ # Streamlit form to take user inputs and handle interaction
66
+ with st.form(key="form1"):
67
+ x = st.number_input("Input X coordinate")
68
+ y = st.number_input("Input Y coordinate")
69
+ submit_button = st.form_submit_button("Submit")
70
+
71
+ if submit_button:
72
+ inferred_embedding = reducer.inverse_transform([[x, y]])
73
+ output = vec2text.invert_embeddings(
74
+ embeddings=torch.tensor(inferred_embedding).cuda(),
75
+ corrector=corrector,
76
+ num_steps=20,
77
+ )
78
+ st.text(str(output))
79
+ st.text(str(inferred_embedding))