marksverdhei commited on
Commit
644a030
·
1 Parent(s): 93f5069

It works now

Browse files
Files changed (1) hide show
  1. app.py +50 -26
app.py CHANGED
@@ -2,20 +2,23 @@ import streamlit as st
2
  import pandas as pd
3
  import vec2text
4
  import torch
5
- from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel
6
  from umap import UMAP
7
  from tqdm import tqdm
8
  import plotly.express as px
9
  import numpy as np
10
  from sklearn.decomposition import PCA
 
 
 
11
  # Activate tqdm with pandas
12
  tqdm.pandas()
13
 
14
  @st.cache_resource
15
  def vector_compressor_from_config():
16
  'TODO'
17
- # return PCA()
18
- return UMAP()
19
 
20
  # Caching the dataframe since loading from external source can be time-consuming
21
  @st.cache_data
@@ -48,7 +51,6 @@ def load_embeddings():
48
  embeddings = load_embeddings()
49
 
50
  # Caching UMAP reduction as it's a heavy computation
51
-
52
  @st.cache_resource
53
  def reduce_embeddings(embeddings):
54
  reducer = vector_compressor_from_config()
@@ -60,29 +62,51 @@ vectors_2d, reducer = reduce_embeddings(embeddings)
60
  fig = px.scatter(
61
  x=vectors_2d[:, 0],
62
  y=vectors_2d[:, 1],
63
- opacity=0.4,
64
  hover_data={"Title": df["title"]},
65
  labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
66
- title="UMAP Scatter Plot of Reddit Titles"
 
 
 
 
 
 
 
 
67
  )
68
 
69
- # Display plot in Streamlit
70
- st.plotly_chart(fig)
71
-
72
- # Streamlit form to take user inputs and handle interaction
73
- with st.form(key="form1"):
74
- x = st.number_input("Input X coordinate")
75
- y = st.number_input("Input Y coordinate")
76
- submit_button = st.form_submit_button("Submit")
77
-
78
- if submit_button:
79
- inferred_embedding = reducer.inverse_transform(np.array([x, y]).T if not isinstance(reducer, UMAP) else np.array([[x, y]]))
80
- inferred_embedding.astype("float32")
81
- print(inferred_embedding.dtype)
82
- output = vec2text.invert_embeddings(
83
- embeddings=torch.tensor(inferred_embedding).cuda(),
84
- corrector=corrector,
85
- num_steps=20,
86
- )
87
- st.text(str(output))
88
- st.text(str(inferred_embedding))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pandas as pd
3
  import vec2text
4
  import torch
5
+ from transformers import AutoModel, AutoTokenizer
6
  from umap import UMAP
7
  from tqdm import tqdm
8
  import plotly.express as px
9
  import numpy as np
10
  from sklearn.decomposition import PCA
11
+ from streamlit_plotly_events import plotly_events
12
+ import plotly.graph_objects as go
13
+ import logging
14
  # Activate tqdm with pandas
15
  tqdm.pandas()
16
 
17
  @st.cache_resource
18
  def vector_compressor_from_config():
19
  'TODO'
20
+ # return PCA(2)
21
+ return UMAP(2)
22
 
23
  # Caching the dataframe since loading from external source can be time-consuming
24
  @st.cache_data
 
51
  embeddings = load_embeddings()
52
 
53
  # Caching UMAP reduction as it's a heavy computation
 
54
  @st.cache_resource
55
  def reduce_embeddings(embeddings):
56
  reducer = vector_compressor_from_config()
 
62
  fig = px.scatter(
63
  x=vectors_2d[:, 0],
64
  y=vectors_2d[:, 1],
65
+ opacity=0.6,
66
  hover_data={"Title": df["title"]},
67
  labels={'x': 'UMAP Dimension 1', 'y': 'UMAP Dimension 2'},
68
+ title="UMAP Scatter Plot of Reddit Titles",
69
+ color_discrete_sequence=["#01a8d3"] # Set default blue color for points
70
+ )
71
+
72
+ # Customize the layout to adapt to browser settings (light/dark mode)
73
+ fig.update_layout(
74
+ template=None, # Let Plotly adapt automatically based on user settings
75
+ plot_bgcolor="rgba(0, 0, 0, 0)",
76
+ paper_bgcolor="rgba(0, 0, 0, 0)"
77
  )
78
 
79
+ # Display the scatterplot and capture click events
80
+ selected_points = plotly_events(fig, click_event=True, hover_event=False, override_height=600, override_width="100%")
81
+
82
+
83
+ # If a point is clicked, handle the embedding inversion
84
+ if selected_points:
85
+
86
+ clicked_point = selected_points[0]
87
+ x_coord = x = clicked_point['x']
88
+ y_coord = y = clicked_point['y']
89
+ st.text(f"Embeddings shape: {embeddings.shape}")
90
+ st.text(f"2dvector shapes shape: {vectors_2d.shape}")
91
+ st.text(f"Clicked point coordinates: x = {x_coord}, y = {y_coord}")
92
+ st.text("fOO")
93
+ logging.info("Foo")
94
+ inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
95
+ logging.info("Bar")
96
+
97
+ st.text("Bar")
98
+
99
+ inferred_embedding = inferred_embedding.astype("float32")
100
+ st.text("Bar")
101
+
102
+ output = vec2text.invert_embeddings(
103
+ embeddings=torch.tensor(inferred_embedding).cuda(),
104
+ corrector=corrector,
105
+ num_steps=20,
106
+ )
107
+ st.text("Bar")
108
+
109
+ st.text(str(output))
110
+ st.text(str(inferred_embedding))
111
+ else:
112
+ st.text("Click on a point in the scatterplot to see its coordinates.")