marksverdhei committed on
Commit
b21feb2
·
1 Parent(s): c0f5faa

Run the simplest test

Browse files
Files changed (1) hide show
  1. app.py +157 -151
app.py CHANGED
@@ -1,152 +1,158 @@
1
- import os
2
- import pickle
 
 
 
3
  import streamlit as st
4
- import pandas as pd
5
- import vec2text
6
- import torch
7
- from transformers import AutoModel, AutoTokenizer
8
- from umap import UMAP
9
- from tqdm import tqdm
10
- import plotly.express as px
11
- import numpy as np
12
- from sklearn.decomposition import PCA
13
- # from streamlit_plotly_events import plotly_events
14
- import plotly.graph_objects as go
15
- import logging
16
- import utils
17
- # Activate tqdm with pandas
18
- tqdm.pandas()
19
-
20
- # Custom file cache decorator
21
- def file_cache(file_path):
22
- def decorator(func):
23
- def wrapper(*args, **kwargs):
24
- # Check if the file already exists
25
- if os.path.exists(file_path):
26
- # Load from cache
27
- with open(file_path, "rb") as f:
28
- print(f"Loading cached data from {file_path}")
29
- return pickle.load(f)
30
- else:
31
- # Compute and save to cache
32
- result = func(*args, **kwargs)
33
- with open(file_path, "wb") as f:
34
- pickle.dump(result, f)
35
- print(f"Saving new cache to {file_path}")
36
- return result
37
- return wrapper
38
- return decorator
39
-
40
- @st.cache_resource
41
- def vector_compressor_from_config():
42
- # Return a 2-component reducer for dimensionality reduction (PCA; UMAP kept below as a commented-out alternative)
43
- # return UMAP(n_components=2)
44
- return PCA(n_components=2)
45
-
46
-
47
- # Caching the dataframe since loading from an external source can be time-consuming
48
- @st.cache_data
49
- def load_data():
50
- return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
51
-
52
- df = load_data()
53
-
54
- # Caching the model and tokenizer to avoid reloading
55
- @st.cache_resource
56
- def load_model_and_tokenizer():
57
- encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
58
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
59
- return encoder, tokenizer
60
-
61
- encoder, tokenizer = load_model_and_tokenizer()
62
-
63
- # Caching the vec2text corrector
64
- @st.cache_resource
65
- def load_corrector():
66
- return vec2text.load_pretrained_corrector("gtr-base")
67
-
68
- corrector = load_corrector()
69
-
70
- # Caching the precomputed embeddings since they are stored locally and large
71
- @st.cache_data
72
- def load_embeddings():
73
- return np.load("syac-title-embeddings.npy")
74
-
75
- embeddings = load_embeddings()
76
-
77
- # Cache the dimensionality reduction (currently PCA) on disk using the file_cache decorator
78
- @st.cache_data
79
- @file_cache(".cache/reducer_embeddings.pickle")
80
- def reduce_embeddings(embeddings):
81
- reducer = vector_compressor_from_config()
82
- return reducer.fit_transform(embeddings), reducer
83
-
84
- vectors_2d, reducer = reduce_embeddings(embeddings)
85
-
86
- # Add a scatter plot using Plotly
87
- fig = px.scatter(
88
- x=vectors_2d[:, 0],
89
- y=vectors_2d[:, 1],
90
- opacity=0.6,
91
- hover_data={"Title": df["title"]},
92
- labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
93
- title="UMAP Scatter Plot of Reddit Titles",
94
- color_discrete_sequence=["#ff504c"] # Set default red color for points
95
- )
96
-
97
- # Customize the layout to adapt to browser settings (light/dark mode)
98
- fig.update_layout(
99
- template=None, # Let Plotly adapt automatically based on user settings
100
- plot_bgcolor="rgba(0, 0, 0, 0)",
101
- paper_bgcolor="rgba(0, 0, 0, 0)"
102
- )
103
-
104
- x, y = 0.0, 0.0
105
- vec = np.array([x, y]).astype("float32")
106
-
107
- # Add a card container to the right of the content with Streamlit columns
108
- col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
109
-
110
- with col1:
111
- # Main content stays here (scatterplot, form, etc.)
112
- # selected_points = plotly_events(fig, click_event=True, hover_event=False,
113
- # )
114
- selected_points = None
115
- with st.form(key="form1_main"):
116
- if selected_points:
117
- clicked_point = selected_points[0]
118
- x_coord = x = clicked_point['x']
119
- y_coord = y = clicked_point['y']
120
-
121
- x = st.number_input("X Coordinate", value=x, format="%.10f")
122
- y = st.number_input("Y Coordinate", value=y, format="%.10f")
123
- vec = np.array([x, y]).astype("float32")
124
-
125
-
126
- submit_button = st.form_submit_button("Submit")
127
-
128
- if selected_points or submit_button:
129
- inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
130
- inferred_embedding = inferred_embedding.astype("float32")
131
-
132
- output = vec2text.invert_embeddings(
133
- embeddings=torch.tensor(inferred_embedding).cuda(),
134
- corrector=corrector,
135
- num_steps=20,
136
- )
137
-
138
- st.text(str(output))
139
- st.text(str(inferred_embedding))
140
- else:
141
- st.text("Click on a point in the scatterplot to see its coordinates.")
142
-
143
- with col2:
144
- closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
145
- st.write(f"{vectors_2d.dtype} {vec.dtype}")
146
- if closest_sentence_index > -1:
147
- st.write(df["title"].iloc[closest_sentence_index])
148
- # Card content
149
- st.markdown("## Card Container")
150
- st.write("This is an additional card container to the right of the main content.")
151
- st.write("You can use this space to show additional information, actions, or insights.")
152
- st.button("Card Button")
 
 
 
 
1
+
2
+
3
+
4
+ # import os
5
+ # import pickle
6
  import streamlit as st
7
+
8
+
9
+ st.text("This is a test")
10
+ # import pandas as pd
11
+ # import vec2text
12
+ # import torch
13
+ # from transformers import AutoModel, AutoTokenizer
14
+ # from umap import UMAP
15
+ # from tqdm import tqdm
16
+ # import plotly.express as px
17
+ # import numpy as np
18
+ # from sklearn.decomposition import PCA
19
+ # # from streamlit_plotly_events import plotly_events
20
+ # import plotly.graph_objects as go
21
+ # import logging
22
+ # import utils
23
+ # # Activate tqdm with pandas
24
+ # tqdm.pandas()
25
+
26
+ # # Custom file cache decorator
27
+ # def file_cache(file_path):
28
+ # def decorator(func):
29
+ # def wrapper(*args, **kwargs):
30
+ # # Check if the file already exists
31
+ # if os.path.exists(file_path):
32
+ # # Load from cache
33
+ # with open(file_path, "rb") as f:
34
+ # print(f"Loading cached data from {file_path}")
35
+ # return pickle.load(f)
36
+ # else:
37
+ # # Compute and save to cache
38
+ # result = func(*args, **kwargs)
39
+ # with open(file_path, "wb") as f:
40
+ # pickle.dump(result, f)
41
+ # print(f"Saving new cache to {file_path}")
42
+ # return result
43
+ # return wrapper
44
+ # return decorator
45
+
46
+ # @st.cache_resource
47
+ # def vector_compressor_from_config():
48
+ # # Return a 2-component reducer for dimensionality reduction (PCA; UMAP kept below as a commented-out alternative)
49
+ # # return UMAP(n_components=2)
50
+ # return PCA(n_components=2)
51
+
52
+
53
+ # # Caching the dataframe since loading from an external source can be time-consuming
54
+ # @st.cache_data
55
+ # def load_data():
56
+ # return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
57
+
58
+ # df = load_data()
59
+
60
+ # # Caching the model and tokenizer to avoid reloading
61
+ # # @st.cache_resource
62
+ # # def load_model_and_tokenizer():
63
+ # # encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
64
+ # # tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
65
+ # # return encoder, tokenizer
66
+
67
+ # # encoder, tokenizer = load_model_and_tokenizer()
68
+
69
+ # # Caching the vec2text corrector
70
+ # # @st.cache_resource
71
+ # # def load_corrector():
72
+ # # return vec2text.load_pretrained_corrector("gtr-base")
73
+
74
+ # # corrector = load_corrector()
75
+
76
+ # # Caching the precomputed embeddings since they are stored locally and large
77
+ # @st.cache_data
78
+ # def load_embeddings():
79
+ # return np.load("syac-title-embeddings.npy")
80
+
81
+ # embeddings = load_embeddings()
82
+
83
+ # # Cache the dimensionality reduction (currently PCA) on disk using the file_cache decorator
84
+ # @st.cache_data
85
+ # @file_cache(".cache/reducer_embeddings.pickle")
86
+ # def reduce_embeddings(embeddings):
87
+ # reducer = vector_compressor_from_config()
88
+ # return reducer.fit_transform(embeddings), reducer
89
+
90
+ # vectors_2d, reducer = reduce_embeddings(embeddings)
91
+
92
+ # # Add a scatter plot using Plotly
93
+ # # fig = px.scatter(
94
+ # # x=vectors_2d[:, 0],
95
+ # # y=vectors_2d[:, 1],
96
+ # # opacity=0.6,
97
+ # # hover_data={"Title": df["title"]},
98
+ # # labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
99
+ # # title="UMAP Scatter Plot of Reddit Titles",
100
+ # # color_discrete_sequence=["#ff504c"] # Set default red color for points
101
+ # # )
102
+
103
+ # # # Customize the layout to adapt to browser settings (light/dark mode)
104
+ # # fig.update_layout(
105
+ # # template=None, # Let Plotly adapt automatically based on user settings
106
+ # # plot_bgcolor="rgba(0, 0, 0, 0)",
107
+ # # paper_bgcolor="rgba(0, 0, 0, 0)"
108
+ # # )
109
+
110
+ # x, y = 0.0, 0.0
111
+ # vec = np.array([x, y]).astype("float32")
112
+
113
+ # # Add a card container to the right of the content with Streamlit columns
114
+ # col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
115
+
116
+ # with col1:
117
+ # # Main content stays here (scatterplot, form, etc.)
118
+ # # selected_points = plotly_events(fig, click_event=True, hover_event=False,
119
+ # # )
120
+ # selected_points = None
121
+ # with st.form(key="form1_main"):
122
+ # if selected_points:
123
+ # clicked_point = selected_points[0]
124
+ # x_coord = x = clicked_point['x']
125
+ # y_coord = y = clicked_point['y']
126
+
127
+ # x = st.number_input("X Coordinate", value=x, format="%.10f")
128
+ # y = st.number_input("Y Coordinate", value=y, format="%.10f")
129
+ # vec = np.array([x, y]).astype("float32")
130
+
131
+
132
+ # submit_button = st.form_submit_button("Submit")
133
+
134
+ # if selected_points or submit_button:
135
+ # inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
136
+ # inferred_embedding = inferred_embedding.astype("float32")
137
+
138
+ # output = vec2text.invert_embeddings(
139
+ # embeddings=torch.tensor(inferred_embedding).cuda(),
140
+ # corrector=corrector,
141
+ # num_steps=20,
142
+ # )
143
+
144
+ # st.text(str(output))
145
+ # st.text(str(inferred_embedding))
146
+ # else:
147
+ # st.text("Click on a point in the scatterplot to see its coordinates.")
148
+
149
+ # with col2:
150
+ # closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
151
+ # st.write(f"{vectors_2d.dtype} {vec.dtype}")
152
+ # if closest_sentence_index > -1:
153
+ # st.write(df["title"].iloc[closest_sentence_index])
154
+ # # Card content
155
+ # st.markdown("## Card Container")
156
+ # st.write("This is an additional card container to the right of the main content.")
157
+ # st.write("You can use this space to show additional information, actions, or insights.")
158
+ # st.button("Card Button")