marksverdhei commited on
Commit
93f5069
·
1 Parent(s): 8aa44e7

Update program

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -7,10 +7,16 @@ from umap import UMAP
7
  from tqdm import tqdm
8
  import plotly.express as px
9
  import numpy as np
10
-
11
  # Activate tqdm with pandas
12
  tqdm.pandas()
13
 
 
 
 
 
 
 
14
  # Caching the dataframe since loading from external source can be time-consuming
15
  @st.cache_data
16
  def load_data():
@@ -42,9 +48,10 @@ def load_embeddings():
42
  embeddings = load_embeddings()
43
 
44
  # Caching UMAP reduction as it's a heavy computation
 
45
  @st.cache_resource
46
  def reduce_embeddings(embeddings):
47
- reducer = UMAP()
48
  return reducer.fit_transform(embeddings), reducer
49
 
50
  vectors_2d, reducer = reduce_embeddings(embeddings)
@@ -69,7 +76,9 @@ with st.form(key="form1"):
69
  submit_button = st.form_submit_button("Submit")
70
 
71
  if submit_button:
72
- inferred_embedding = reducer.inverse_transform([[x, y]])
 
 
73
  output = vec2text.invert_embeddings(
74
  embeddings=torch.tensor(inferred_embedding).cuda(),
75
  corrector=corrector,
 
7
  from tqdm import tqdm
8
  import plotly.express as px
9
  import numpy as np
10
+ from sklearn.decomposition import PCA
11
  # Activate tqdm with pandas
12
  tqdm.pandas()
13
 
14
+ @st.cache_resource
15
+ def vector_compressor_from_config():
16
+ 'TODO'
17
+ # return PCA()
18
+ return UMAP()
19
+
20
  # Caching the dataframe since loading from external source can be time-consuming
21
  @st.cache_data
22
  def load_data():
 
48
  embeddings = load_embeddings()
49
 
50
  # Caching UMAP reduction as it's a heavy computation
51
+
52
  @st.cache_resource
53
  def reduce_embeddings(embeddings):
54
+ reducer = vector_compressor_from_config()
55
  return reducer.fit_transform(embeddings), reducer
56
 
57
  vectors_2d, reducer = reduce_embeddings(embeddings)
 
76
  submit_button = st.form_submit_button("Submit")
77
 
78
  if submit_button:
79
+ inferred_embedding = reducer.inverse_transform(np.array([x, y]).T if not isinstance(reducer, UMAP) else np.array([[x, y]]))
80
+ inferred_embedding.astype("float32")
81
+ print(inferred_embedding.dtype)
82
  output = vec2text.invert_embeddings(
83
  embeddings=torch.tensor(inferred_embedding).cuda(),
84
  corrector=corrector,