Commit
·
93f5069
1
Parent(s):
8aa44e7
Update program
Browse files
app.py
CHANGED
@@ -7,10 +7,16 @@ from umap import UMAP
|
|
7 |
from tqdm import tqdm
|
8 |
import plotly.express as px
|
9 |
import numpy as np
|
10 |
-
|
11 |
# Activate tqdm with pandas
|
12 |
tqdm.pandas()
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# Caching the dataframe since loading from external source can be time-consuming
|
15 |
@st.cache_data
|
16 |
def load_data():
|
@@ -42,9 +48,10 @@ def load_embeddings():
|
|
42 |
embeddings = load_embeddings()
|
43 |
|
44 |
# Caching UMAP reduction as it's a heavy computation
|
|
|
45 |
@st.cache_resource
|
46 |
def reduce_embeddings(embeddings):
|
47 |
-
reducer =
|
48 |
return reducer.fit_transform(embeddings), reducer
|
49 |
|
50 |
vectors_2d, reducer = reduce_embeddings(embeddings)
|
@@ -69,7 +76,9 @@ with st.form(key="form1"):
|
|
69 |
submit_button = st.form_submit_button("Submit")
|
70 |
|
71 |
if submit_button:
|
72 |
-
inferred_embedding = reducer.inverse_transform([[x, y]])
|
|
|
|
|
73 |
output = vec2text.invert_embeddings(
|
74 |
embeddings=torch.tensor(inferred_embedding).cuda(),
|
75 |
corrector=corrector,
|
|
|
7 |
from tqdm import tqdm
|
8 |
import plotly.express as px
|
9 |
import numpy as np
|
10 |
+
from sklearn.decomposition import PCA
|
11 |
# Activate tqdm with pandas
|
12 |
tqdm.pandas()
|
13 |
|
14 |
+
@st.cache_resource
|
15 |
+
def vector_compressor_from_config():
|
16 |
+
'TODO'
|
17 |
+
# return PCA()
|
18 |
+
return UMAP()
|
19 |
+
|
20 |
# Caching the dataframe since loading from external source can be time-consuming
|
21 |
@st.cache_data
|
22 |
def load_data():
|
|
|
48 |
embeddings = load_embeddings()
|
49 |
|
50 |
# Caching UMAP reduction as it's a heavy computation
|
51 |
+
|
52 |
@st.cache_resource
|
53 |
def reduce_embeddings(embeddings):
|
54 |
+
reducer = vector_compressor_from_config()
|
55 |
return reducer.fit_transform(embeddings), reducer
|
56 |
|
57 |
vectors_2d, reducer = reduce_embeddings(embeddings)
|
|
|
76 |
submit_button = st.form_submit_button("Submit")
|
77 |
|
78 |
if submit_button:
|
79 |
+
inferred_embedding = reducer.inverse_transform(np.array([x, y]).T if not isinstance(reducer, UMAP) else np.array([[x, y]]))
|
80 |
+
inferred_embedding.astype("float32")
|
81 |
+
print(inferred_embedding.dtype)
|
82 |
output = vec2text.invert_embeddings(
|
83 |
embeddings=torch.tensor(inferred_embedding).cuda(),
|
84 |
corrector=corrector,
|