marksverdhei committed on
Commit
2b22bff
·
1 Parent(s): 2063af3

Add file cache

Browse files
Files changed (1) hide show
  1. app.py +29 -9
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import vec2text
@@ -11,16 +13,36 @@ from sklearn.decomposition import PCA
11
  from streamlit_plotly_events import plotly_events
12
  import plotly.graph_objects as go
13
  import logging
 
14
  # Activate tqdm with pandas
15
  tqdm.pandas()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  @st.cache_resource
18
  def vector_compressor_from_config():
19
- 'TODO'
20
- # return PCA(n:n_components=2)
21
  return UMAP(n_components=2)
22
 
23
- # Caching the dataframe since loading from external source can be time-consuming
24
  @st.cache_data
25
  def load_data():
26
  return pd.read_csv("https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv")
@@ -50,8 +72,9 @@ def load_embeddings():
50
 
51
  embeddings = load_embeddings()
52
 
53
- # Caching UMAP reduction as it's a heavy computation
54
- @st.cache_resource
 
55
  def reduce_embeddings(embeddings):
56
  reducer = vector_compressor_from_config()
57
  return reducer.fit_transform(embeddings), reducer
@@ -79,15 +102,12 @@ fig.update_layout(
79
  # Display the scatterplot and capture click events
80
  selected_points = plotly_events(fig, click_event=True, hover_event=False, override_height=600, override_width="100%")
81
 
82
-
83
  # If a point is clicked, handle the embedding inversion
84
  if selected_points:
85
-
86
  clicked_point = selected_points[0]
87
  x_coord = x = clicked_point['x']
88
  y_coord = y = clicked_point['y']
89
 
90
-
91
  inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
92
 
93
  inferred_embedding = inferred_embedding.astype("float32")
@@ -101,4 +121,4 @@ if selected_points:
101
  st.text(str(output))
102
  st.text(str(inferred_embedding))
103
  else:
104
- st.text("Click on a point in the scatterplot to see its coordinates.")
 
1
+ import os
2
+ import pickle
3
  import streamlit as st
4
  import pandas as pd
5
  import vec2text
 
13
  from streamlit_plotly_events import plotly_events
14
  import plotly.graph_objects as go
15
  import logging
16
+
17
  # Activate tqdm with pandas
18
  tqdm.pandas()
19
 
20
+ # Custom file cache decorator
21
def file_cache(file_path):
    """Return a decorator that persists a function's result to a pickle file.

    On each call of the decorated function, if ``file_path`` already exists
    its unpickled contents are returned and the wrapped function is not
    called. Otherwise the wrapped function is called, its result is pickled
    to ``file_path`` (creating the parent directory if needed) and returned.

    NOTE: the cache is keyed only on ``file_path``. Arguments passed to the
    decorated function are ignored when a cache file exists, so a stale file
    must be deleted manually if the inputs change.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only point
    ``file_path`` at files this application wrote itself.

    Args:
        file_path: Path of the pickle file used as the cache.

    Returns:
        A decorator that wraps a function with the file-backed cache.
    """
    # Imported locally so this block does not depend on a module-level import.
    import functools

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            # Check if the file already exists
            if os.path.exists(file_path):
                # Load from cache
                with open(file_path, "rb") as f:
                    print(f"Loading cached data from {file_path}")
                    return pickle.load(f)

            # Compute and save to cache
            result = func(*args, **kwargs)

            # The cache directory (e.g. ".cache/") may not exist on first run;
            # without this, opening the file for writing raises FileNotFoundError.
            cache_dir = os.path.dirname(file_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)

            # Write to a temporary file and rename it into place, so a failed
            # or interrupted dump never leaves a truncated cache file that
            # later calls would try (and fail) to load.
            tmp_path = f"{file_path}.tmp"
            with open(tmp_path, "wb") as f:
                pickle.dump(result, f)
            os.replace(tmp_path, file_path)
            print(f"Saving new cache to {file_path}")
            return result

        return wrapper

    return decorator
39
+
40
  @st.cache_resource
41
  def vector_compressor_from_config():
42
+ # Return UMAP with 2 components for dimensionality reduction
 
43
  return UMAP(n_components=2)
44
 
45
# The CSV is fetched over the network, so the resulting dataframe is cached.
@st.cache_data
def load_data():
    """Read the training CSV from the Hugging Face dataset URL.

    Returns:
        The dataframe produced by ``pd.read_csv`` for the remote file.
    """
    dataset_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(dataset_url)
 
72
 
73
  embeddings = load_embeddings()
74
 
75
# Persist the fitted reduction on disk through the file_cache decorator.
# NOTE(review): file_cache is the outermost decorator and only checks whether
# the pickle file exists, so once the file is present st.cache_data is not
# consulted and the `embeddings` argument is ignored -- confirm this is intended.
@file_cache(".cache/reducer_embeddings.pickle")
@st.cache_data
def reduce_embeddings(embeddings):
    """Fit the configured reducer on the embeddings and project them.

    Args:
        embeddings: Input passed directly to the reducer's ``fit_transform``.

    Returns:
        A ``(reduced, reducer)`` tuple: the ``fit_transform`` output and the
        fitted reducer object itself.
    """
    fitted_reducer = vector_compressor_from_config()
    projected = fitted_reducer.fit_transform(embeddings)
    return projected, fitted_reducer
 
102
  # Display the scatterplot and capture click events
103
  selected_points = plotly_events(fig, click_event=True, hover_event=False, override_height=600, override_width="100%")
104
 
 
105
  # If a point is clicked, handle the embedding inversion
106
  if selected_points:
 
107
  clicked_point = selected_points[0]
108
  x_coord = x = clicked_point['x']
109
  y_coord = y = clicked_point['y']
110
 
 
111
  inferred_embedding = reducer.inverse_transform(np.array([[x, y]]) if not isinstance(reducer, UMAP) else np.array([[x, y]]))
112
 
113
  inferred_embedding = inferred_embedding.astype("float32")
 
121
  st.text(str(output))
122
  st.text(str(inferred_embedding))
123
  else:
124
+ st.text("Click on a point in the scatterplot to see its coordinates.")