marksverdhei committed on
Commit
44fce2f
·
1 Parent(s): b21feb2

Revert back a couple commits

Browse files
Files changed (1) hide show
  1. app.py +150 -157
app.py CHANGED
@@ -1,158 +1,151 @@
1
-
2
-
3
-
4
- # import os
5
- # import pickle
6
  import streamlit as st
7
-
8
-
9
- st.text("This is a test")
10
- # import pandas as pd
11
- # import vec2text
12
- # import torch
13
- # from transformers import AutoModel, AutoTokenizer
14
- # from umap import UMAP
15
- # from tqdm import tqdm
16
- # import plotly.express as px
17
- # import numpy as np
18
- # from sklearn.decomposition import PCA
19
- # # from streamlit_plotly_events import plotly_events
20
- # import plotly.graph_objects as go
21
- # import logging
22
- # import utils
23
- # # Activate tqdm with pandas
24
- # tqdm.pandas()
25
-
26
- # # Custom file cache decorator
27
- # def file_cache(file_path):
28
- # def decorator(func):
29
- # def wrapper(*args, **kwargs):
30
- # # Check if the file already exists
31
- # if os.path.exists(file_path):
32
- # # Load from cache
33
- # with open(file_path, "rb") as f:
34
- # print(f"Loading cached data from {file_path}")
35
- # return pickle.load(f)
36
- # else:
37
- # # Compute and save to cache
38
- # result = func(*args, **kwargs)
39
- # with open(file_path, "wb") as f:
40
- # pickle.dump(result, f)
41
- # print(f"Saving new cache to {file_path}")
42
- # return result
43
- # return wrapper
44
- # return decorator
45
-
46
- # @st.cache_resource
47
- # def vector_compressor_from_config():
48
- # # Return UMAP with 2 components for dimensionality reduction
49
- # # return UMAP(n_components=2)
50
- # return PCA(n_components=2)
51
-
52
-
53
- # # Caching the dataframe since loading from an external source can be time-consuming
54
- # @st.cache_data
55
- # def load_data():
56
- # return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
57
-
58
- # df = load_data()
59
-
60
- # # Caching the model and tokenizer to avoid reloading
61
- # # @st.cache_resource
62
- # # def load_model_and_tokenizer():
63
- # # encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
64
- # # tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
65
- # # return encoder, tokenizer
66
-
67
- # # encoder, tokenizer = load_model_and_tokenizer()
68
-
69
- # # Caching the vec2text corrector
70
- # # @st.cache_resource
71
- # # def load_corrector():
72
- # # return vec2text.load_pretrained_corrector("gtr-base")
73
-
74
- # # corrector = load_corrector()
75
-
76
- # # Caching the precomputed embeddings since they are stored locally and large
77
- # @st.cache_data
78
- # def load_embeddings():
79
- # return np.load("syac-title-embeddings.npy")
80
-
81
- # embeddings = load_embeddings()
82
-
83
- # # Custom cache the UMAP reduction using file_cache decorator
84
- # @st.cache_data
85
- # @file_cache(".cache/reducer_embeddings.pickle")
86
- # def reduce_embeddings(embeddings):
87
- # reducer = vector_compressor_from_config()
88
- # return reducer.fit_transform(embeddings), reducer
89
-
90
- # vectors_2d, reducer = reduce_embeddings(embeddings)
91
-
92
- # # Add a scatter plot using Plotly
93
- # # fig = px.scatter(
94
- # # x=vectors_2d[:, 0],
95
- # # y=vectors_2d[:, 1],
96
- # # opacity=0.6,
97
- # # hover_data={"Title": df["title"]},
98
- # # labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
99
- # # title="UMAP Scatter Plot of Reddit Titles",
100
- # # color_discrete_sequence=["#ff504c"] # Set default blue color for points
101
- # # )
102
-
103
- # # # Customize the layout to adapt to browser settings (light/dark mode)
104
- # # fig.update_layout(
105
- # # template=None, # Let Plotly adapt automatically based on user settings
106
- # # plot_bgcolor="rgba(0, 0, 0, 0)",
107
- # # paper_bgcolor="rgba(0, 0, 0, 0)"
108
- # # )
109
-
110
- # x, y = 0.0, 0.0
111
- # vec = np.array([x, y]).astype("float32")
112
-
113
- # # Add a card container to the right of the content with Streamlit columns
114
- # col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
115
-
116
- # with col1:
117
- # # Main content stays here (scatterplot, form, etc.)
118
- # # selected_points = plotly_events(fig, click_event=True, hover_event=False,
119
- # # )
120
- # selected_points = None
121
- # with st.form(key="form1_main"):
122
- # if selected_points:
123
- # clicked_point = selected_points[0]
124
- # x_coord = x = clicked_point['x']
125
- # y_coord = y = clicked_point['y']
126
-
127
- # x = st.number_input("X Coordinate", value=x, format="%.10f")
128
- # y = st.number_input("Y Coordinate", value=y, format="%.10f")
129
- # vec = np.array([x, y]).astype("float32")
130
-
131
-
132
- # submit_button = st.form_submit_button("Submit")
133
-
134
- # if selected_points or submit_button:
135
- # inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
136
- # inferred_embedding = inferred_embedding.astype("float32")
137
-
138
- # output = vec2text.invert_embeddings(
139
- # embeddings=torch.tensor(inferred_embedding).cuda(),
140
- # corrector=corrector,
141
- # num_steps=20,
142
- # )
143
-
144
- # st.text(str(output))
145
- # st.text(str(inferred_embedding))
146
- # else:
147
- # st.text("Click on a point in the scatterplot to see its coordinates.")
148
-
149
- # with col2:
150
- # closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
151
- # st.write(f"{vectors_2d.dtype} {vec.dtype}")
152
- # if closest_sentence_index > -1:
153
- # st.write(df["title"].iloc[closest_sentence_index])
154
- # # Card content
155
- # st.markdown("## Card Container")
156
- # st.write("This is an additional card container to the right of the main content.")
157
- # st.write("You can use this space to show additional information, actions, or insights.")
158
- # st.button("Card Button")
 
1
+ import os
2
+ import pickle
 
 
 
3
  import streamlit as st
4
+ import pandas as pd
5
+ import vec2text
6
+ import torch
7
+ from transformers import AutoModel, AutoTokenizer
8
+ from umap import UMAP
9
+ from tqdm import tqdm
10
+ import plotly.express as px
11
+ import numpy as np
12
+ from sklearn.decomposition import PCA
13
+ from streamlit_plotly_events import plotly_events
14
+ import plotly.graph_objects as go
15
+ import logging
16
+ import utils
17
+ # Activate tqdm with pandas
18
+ tqdm.pandas()
19
+
20
+ # Custom file cache decorator
21
def file_cache(file_path):
    """Build a decorator that persists a function's result in a pickle file.

    When the decorated function is called and ``file_path`` already exists,
    the unpickled file contents are returned and the function itself is not
    invoked. Otherwise the function is called, its result is pickled to
    ``file_path`` (creating the parent directory if needed) and returned.

    The cache is keyed on ``file_path`` only: call arguments are ignored, so
    the stored value is returned even if the function is later called with
    different arguments. Delete the file to force recomputation.

    Args:
        file_path: Location of the pickle file used as the cache.

    Returns:
        A decorator to apply to the function whose result should be cached.
    """
    # Local import so this block does not depend on the file's import section.
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if os.path.exists(file_path):
                # NOTE: unpickling can execute arbitrary code; only point
                # file_path at files this application wrote itself.
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            result = func(*args, **kwargs)

            # Make sure the target directory exists; otherwise open() fails
            # after the (possibly expensive) computation has already run.
            cache_dir = os.path.dirname(file_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)

            with open(file_path, "wb") as f:
                pickle.dump(result, f)
                print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator
39
+
40
@st.cache_resource
def vector_compressor_from_config():
    """Create the reducer used to project embeddings onto two components.

    Returns a scikit-learn ``PCA`` instance; a UMAP alternative is kept
    below as a commented-out line.
    """
    # return UMAP(n_components=2)
    target_dims = 2
    compressor = PCA(n_components=target_dims)
    return compressor
45
+
46
+
47
+ # Caching the dataframe since loading from an external source can be time-consuming
48
@st.cache_data
def load_data():
    """Read the reddit-syac-urls ``train.csv`` from the Hugging Face Hub.

    Returns:
        The CSV contents as a pandas DataFrame.
    """
    train_csv_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    frame = pd.read_csv(train_csv_url)
    return frame
51
+
52
+ df = load_data()
53
+
54
+ # Caching the model and tokenizer to avoid reloading
55
@st.cache_resource
def load_model_and_tokenizer():
    """Load the ``sentence-transformers/gtr-t5-base`` encoder and tokenizer.

    The encoder is moved to the GPU when CUDA is available and stays on the
    CPU otherwise, so the app can still start on hosts without a GPU.

    Returns:
        A ``(encoder, tokenizer)`` tuple.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    # Previously hard-coded to "cuda", which raises on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = AutoModel.from_pretrained(checkpoint).encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tokenizer
60
+
61
+ encoder, tokenizer = load_model_and_tokenizer()
62
+
63
+ # Caching the vec2text corrector
64
@st.cache_resource
def load_corrector():
    """Load the pretrained vec2text corrector named ``"gtr-base"``."""
    corrector_name = "gtr-base"
    loaded = vec2text.load_pretrained_corrector(corrector_name)
    return loaded
67
+
68
+ corrector = load_corrector()
69
+
70
+ # Caching the precomputed embeddings since they are stored locally and large
71
@st.cache_data
def load_embeddings():
    """Load the precomputed title embeddings from the local ``.npy`` file."""
    embeddings_path = "syac-title-embeddings.npy"
    title_embeddings = np.load(embeddings_path)
    return title_embeddings
74
+
75
+ embeddings = load_embeddings()
76
+
77
+ # Cache the dimensionality reduction on disk using the file_cache decorator
78
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    Args:
        embeddings: Array of embeddings passed to the reducer's
            ``fit_transform``.

    Returns:
        A ``(projected, reducer)`` tuple: the transformed embeddings and the
        fitted reducer obtained from ``vector_compressor_from_config``.
    """
    compressor = vector_compressor_from_config()
    projected = compressor.fit_transform(embeddings)
    return projected, compressor
83
+
84
+ vectors_2d, reducer = reduce_embeddings(embeddings)
85
+
86
+ # Add a scatter plot using Plotly
87
# Scatter plot of the 2-D projected title embeddings; hovering shows the title.
scatter_kwargs = {
    "x": vectors_2d[:, 0],
    "y": vectors_2d[:, 1],
    "opacity": 0.6,
    "hover_data": {"Title": df["title"]},
    "labels": {'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
    "title": "UMAP Scatter Plot of Reddit Titles",
    # Single marker colour used for every point.
    "color_discrete_sequence": ["#ff504c"],
}
fig = px.scatter(**scatter_kwargs)

# No explicit template, and fully transparent plot/paper backgrounds so the
# page background shows through.
transparent_bg = "rgba(0, 0, 0, 0)"
fig.update_layout(
    template=None,
    plot_bgcolor=transparent_bg,
    paper_bgcolor=transparent_bg,
)
103
+
104
+ x, y = 0.0, 0.0
105
+ vec = np.array([x, y]).astype("float32")
106
+
107
+ # Add a card container to the right of the content with Streamlit columns
108
+ col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
109
+
110
+ with col1:
111
+ # Main content stays here (scatterplot, form, etc.)
112
+ selected_points = plotly_events(fig, click_event=True, hover_event=False, #override_height=600, override_width="100%"
113
+ )
114
+ with st.form(key="form1_main"):
115
+ if selected_points:
116
+ clicked_point = selected_points[0]
117
+ x_coord = x = clicked_point['x']
118
+ y_coord = y = clicked_point['y']
119
+
120
+ x = st.number_input("X Coordinate", value=x, format="%.10f")
121
+ y = st.number_input("Y Coordinate", value=y, format="%.10f")
122
+ vec = np.array([x, y]).astype("float32")
123
+
124
+
125
+ submit_button = st.form_submit_button("Submit")
126
+
127
+ if selected_points or submit_button:
128
+ inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
129
+ inferred_embedding = inferred_embedding.astype("float32")
130
+
131
+ output = vec2text.invert_embeddings(
132
+ embeddings=torch.tensor(inferred_embedding).cuda(),
133
+ corrector=corrector,
134
+ num_steps=20,
135
+ )
136
+
137
+ st.text(str(output))
138
+ st.text(str(inferred_embedding))
139
+ else:
140
+ st.text("Click on a point in the scatterplot to see its coordinates.")
141
+
142
+ with col2:
143
+ closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
144
+ st.write(f"{vectors_2d.dtype} {vec.dtype}")
145
+ if closest_sentence_index > -1:
146
+ st.write(df["title"].iloc[closest_sentence_index])
147
+ # Card content
148
+ st.markdown("## Card Container")
149
+ st.write("This is an additional card container to the right of the main content.")
150
+ st.write("You can use this space to show additional information, actions, or insights.")
151
+ st.button("Card Button")