marksverdhei committed on
Commit
8246b3c
·
1 Parent(s): 2e98766
Files changed (1) hide show
  1. app.py +51 -49
app.py CHANGED
@@ -13,7 +13,7 @@ from sklearn.decomposition import PCA
13
  from streamlit_plotly_events import plotly_events
14
  import plotly.graph_objects as go
15
  import logging
16
-
17
  # Activate tqdm with pandas
18
  tqdm.pandas()
19
 
@@ -40,7 +40,9 @@ def file_cache(file_path):
40
  @st.cache_resource
41
  def vector_compressor_from_config():
42
  # Return UMAP with 2 components for dimensionality reduction
43
- return UMAP(n_components=2)
 
 
44
 
45
  # Caching the dataframe since loading from an external source can be time-consuming
46
  @st.cache_data
@@ -100,50 +102,50 @@ fig.update_layout(
100
  )
101
 
102
  x, y = 0.0, 0.0
103
-
104
- # Display the scatterplot and capture click events
105
- selected_points = plotly_events(fig, click_event=True, hover_event=False, override_height=600, override_width="100%")
106
-
107
- # Sidebar for additional information
108
- st.sidebar.header("Scatter Plot Info")
109
- st.sidebar.write("""
110
- This scatter plot visualizes the UMAP dimensionality reduction of Reddit post titles.
111
- Each point represents a post, with similar titles being positioned closer together.
112
- """)
113
- st.sidebar.write("Use the form below to select coordinates or click on a point in the scatter plot.")
114
- st.sidebar.markdown("---")
115
-
116
- st.sidebar.header("How to Use")
117
- st.sidebar.write("""
118
- 1. **Click a point** in the scatter plot to see the corresponding coordinates.
119
- 2. **Adjust the coordinates** using the form inputs if needed.
120
- 3. **Submit** to see the reconstructed text output.
121
- """)
122
-
123
- # Form for inputting coordinates
124
- with st.form(key="form1"):
125
- # If a point is clicked, handle the embedding inversion
126
- if selected_points:
127
- clicked_point = selected_points[0]
128
- x_coord = x = clicked_point['x']
129
- y_coord = y = clicked_point['y']
130
-
131
- x = st.number_input("X Coordinate", value=x, format="%.10f")
132
- y = st.number_input("Y Coordinate", value=y, format="%.10f")
133
-
134
- submit_button = st.form_submit_button("Submit")
135
-
136
- if selected_points or submit_button:
137
- inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
138
- inferred_embedding = inferred_embedding.astype("float32")
139
-
140
- output = vec2text.invert_embeddings(
141
- embeddings=torch.tensor(inferred_embedding).cuda(),
142
- corrector=corrector,
143
- num_steps=20,
144
- )
145
-
146
- st.text(str(output))
147
- st.text(str(inferred_embedding))
148
- else:
149
- st.text("Click on a point in the scatterplot to see its coordinates.")
 
13
  from streamlit_plotly_events import plotly_events
14
  import plotly.graph_objects as go
15
  import logging
16
+ import utils
17
  # Activate tqdm with pandas
18
  tqdm.pandas()
19
 
 
40
  @st.cache_resource
41
  def vector_compressor_from_config():
42
  # Return UMAP with 2 components for dimensionality reduction
43
+ # return UMAP(n_components=2)
44
+ return PCA(n_components=2)
45
+
46
 
47
  # Caching the dataframe since loading from an external source can be time-consuming
48
  @st.cache_data
 
102
  )
103
 
104
  x, y = 0.0, 0.0
105
+ vec = np.array([x, y]).astype("float32")
106
+
107
+ # Add a card container to the right of the content with Streamlit columns
108
+ col1, col2 = st.columns([3, 1]) # Adjusting ratio to allocate space for the card container
109
+
110
+ with col1:
111
+ # Main content stays here (scatterplot, form, etc.)
112
+ selected_points = plotly_events(fig, click_event=True, hover_event=False, #override_height=600, override_width="100%"
113
+ )
114
+ with st.form(key="form1_main"):
115
+ if selected_points:
116
+ clicked_point = selected_points[0]
117
+ x_coord = x = clicked_point['x']
118
+ y_coord = y = clicked_point['y']
119
+
120
+ x = st.number_input("X Coordinate", value=x, format="%.10f")
121
+ y = st.number_input("Y Coordinate", value=y, format="%.10f")
122
+ vec = np.array([x, y]).astype("float32")
123
+
124
+
125
+ submit_button = st.form_submit_button("Submit")
126
+
127
+ if selected_points or submit_button:
128
+ inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
129
+ inferred_embedding = inferred_embedding.astype("float32")
130
+
131
+ output = vec2text.invert_embeddings(
132
+ embeddings=torch.tensor(inferred_embedding).cuda(),
133
+ corrector=corrector,
134
+ num_steps=20,
135
+ )
136
+
137
+ st.text(str(output))
138
+ st.text(str(inferred_embedding))
139
+ else:
140
+ st.text("Click on a point in the scatterplot to see its coordinates.")
141
+
142
+ with col2:
143
+ closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
144
+ st.write(f"{vectors_2d.dtype} {vec.dtype}")
145
+ if closest_sentence_index > -1:
146
+ st.write(df["title"].iloc[closest_sentence_index])
147
+ # Card content
148
+ st.markdown("## Card Container")
149
+ st.write("This is an additional card container to the right of the main content.")
150
+ st.write("You can use this space to show additional information, actions, or insights.")
151
+ st.button("Card Button")