stefanoviel committed
Commit 0fd8f7a · Parent: 1a67af9

caching again

Files changed (1):
  1. src/streamlit_app.py +36 -26
src/streamlit_app.py CHANGED
@@ -20,7 +20,6 @@ CSV_FILE = 'papers_with_abstracts_parallel.csv'
 
 
 # --- Caching Functions ---
-# --- Caching Functions (Unchanged but crucial) ---
 @st.cache_resource
 def load_embedding_model():
     """Loads the Sentence Transformer model and caches it."""
@@ -35,56 +34,57 @@ def load_spell_checker():
 def create_and_save_embeddings(model, data_df):
     """
     Generates and saves document embeddings and the dataframe.
-    This function is called only once if the files don't exist in the persistent directory.
+    This function is called only once if the files don't exist.
     """
     st.info("First time setup: Generating and saving embeddings. This may take a moment...")
-    data_df['text_to_embed'] = data_df['title'].fillna('') + ". " + data_df['abstract'].fillna('')
+    # Combine title and abstract for richer embeddings
+    data_df['text_to_embed'] = data_df['title'] + ". " + data_df['abstract'].fillna('')
 
-    corpus_embeddings = model.encode(
-        data_df['text_to_embed'].tolist(),
-        convert_to_tensor=True,
-        show_progress_bar=True
-    )
+    # Generate embeddings
+    corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
 
+    # Save embeddings and dataframe to /tmp directory
     try:
         torch.save(corpus_embeddings, EMBEDDINGS_FILE)
         data_df.to_pickle(DATA_FILE)
-        st.success("Embeddings and data saved successfully for future sessions!")
+        st.success("Embeddings and data saved successfully!")
     except Exception as e:
-        st.warning(f"Could not save embeddings to persistent storage: {e}. Will regenerate on next session.")
+        st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
 
     return corpus_embeddings, data_df
 
 @st.cache_data
 def load_data_and_embeddings():
     """
-    Loads data and embeddings. It first tries to load from the persistent directory.
-    If files don't exist, it creates them. The results are cached for the current session.
+    Loads the saved embeddings and dataframe from disk.
+    If files don't exist, it calls the creation function.
     """
     model = load_embedding_model()
 
-    if DATA_FILE.exists() and EMBEDDINGS_FILE.exists():
+    # Check if files exist and are readable
+    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
         try:
-            data_df = pd.read_pickle(DATA_FILE)
             corpus_embeddings = torch.load(EMBEDDINGS_FILE)
+            data_df = pd.read_pickle(DATA_FILE)
             return model, corpus_embeddings, data_df
         except Exception as e:
-            st.warning(f"Could not load saved files: {e}. Regenerating...")
+            st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
 
-    # Fallback to creating embeddings if they don't exist
+    st.info("embeding model path exists: " + str(Path(EMBEDDING_MODEL).exists()))
+
+    # Load the raw data from CSV
     try:
         data_df = pd.read_csv(CSV_FILE)
         corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
     except FileNotFoundError:
-        st.error(f"The required data file '{CSV_FILE}' was not found. Please make sure it's in your repository root.")
+        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
         st.stop()
     except Exception as e:
-        st.error(f"An unexpected error occurred while loading data: {e}")
+        st.error(f"Error loading data: {e}")
         st.stop()
 
     return model, corpus_embeddings, data_df
 
-# ... (The rest of your functions `correct_query_spelling` and `semantic_search` remain the same) ...
 def correct_query_spelling(query, spell_checker):
     """
     Corrects potential spelling mistakes in the user's query.
@@ -153,13 +153,12 @@ The search is performed by comparing the semantic meaning of your query with the
 Spelling mistakes in your query will be automatically corrected.
 """)
 
-# --- App Logic ---
+# Load all necessary data
 try:
-    # Load all necessary data using the corrected function
     model, corpus_embeddings, data_df = load_data_and_embeddings()
     spell_checker = load_spell_checker()
 
-    # --- User Inputs ---
+    # --- User Inputs: Search Bar and Slider ---
     col1, col2 = st.columns([4, 1])
     with col1:
         search_query = st.text_input(
@@ -170,26 +169,37 @@ try:
         top_k_results = st.number_input(
             "Number of results",
             min_value=1,
-            max_value=100,
+            max_value=100, # Set a reasonable max
             value=10,
             help="Select the number of top results to display."
         )
 
     if search_query:
+        # --- Perform Typo Correction ---
         corrected_query = correct_query_spelling(search_query, spell_checker)
 
+        # If a correction was made, notify the user
         if corrected_query.lower() != search_query.lower():
             st.info(f"Did you mean: **{corrected_query}**? \n\n*Showing results for the corrected query.*")
 
-        search_results = semantic_search(corrected_query, model, corpus_embeddings, data_df, top_k=top_k_results)
+        final_query = corrected_query
+
+        # --- Perform Search ---
+        search_results = semantic_search(final_query, model, corpus_embeddings, data_df, top_k=top_k_results)
 
-        st.subheader(f"Found {len(search_results)} results for '{corrected_query}'")
+        st.subheader(f"Found {len(search_results)} results for '{final_query}'")
 
+        # --- Display Results ---
         if search_results:
             for result in search_results:
                 with st.container(border=True):
+                    # Title as a clickable link
                     st.markdown(f"### [{result['title']}]({result['link']})")
+
+                    # Authors
                     st.caption(f"**Authors:** {result['authors']}")
+
+                    # Expander for the abstract
                     if pd.notna(result['abstract']):
                         with st.expander("View Abstract"):
                             st.write(result['abstract'])
@@ -197,5 +207,5 @@ try:
             st.warning("No results found. Try a different query.")
 
 except Exception as e:
-    st.error(f"An error occurred during app execution: {e}")
+    st.error(f"An error occurred: {e}")
     st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")
 
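For context on what the diff relies on, here is a minimal sketch of the two-level caching pattern: `@st.cache_resource` keeps a single shared copy of the heavy embedding model, while `@st.cache_data` caches the serializable embeddings and dataframe, and a file-existence check persists them across restarts. The model id and the /tmp paths below are illustrative assumptions; only the CSV name comes from the diff.

```python
# Sketch of the caching pattern, assuming streamlit, sentence-transformers,
# torch, and pandas are installed. Paths and the model id are placeholders.
import os

import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer

EMBEDDINGS_FILE = "/tmp/corpus_embeddings.pt"  # hypothetical path
DATA_FILE = "/tmp/papers.pkl"                  # hypothetical path
CSV_FILE = "papers_with_abstracts_parallel.csv"

@st.cache_resource  # one shared, in-memory object per server process
def load_embedding_model():
    return SentenceTransformer("all-MiniLM-L6-v2")  # assumed model id

@st.cache_data  # return value is serialized and reused across reruns
def load_data_and_embeddings():
    model = load_embedding_model()
    if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
        return torch.load(EMBEDDINGS_FILE), pd.read_pickle(DATA_FILE)
    df = pd.read_csv(CSV_FILE)
    texts = (df["title"].fillna("") + ". " + df["abstract"].fillna("")).tolist()
    emb = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
    torch.save(emb, EMBEDDINGS_FILE)
    df.to_pickle(DATA_FILE)
    return emb, df
```

Unlike the committed code, this sketch keeps the model out of the `@st.cache_data` return value: `st.cache_data` serializes whatever it returns, whereas `st.cache_resource` hands back the same object every time.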
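The commit does not touch `semantic_search` or `correct_query_spelling`, and the diff never shows their bodies. As a rough, hypothetical sketch (not the repository's code), a `semantic_search` with the signature used above could be built on `sentence_transformers.util.semantic_search`:

```python
# Hypothetical sketch of the search helper the app calls; the result keys
# mirror the columns the display code expects (title, link, authors, abstract).
from sentence_transformers import util

def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    # Embed the query and rank the corpus by cosine similarity
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    results = []
    for hit in hits:
        row = data_df.iloc[hit["corpus_id"]]
        results.append({
            "title": row["title"],
            "link": row.get("link", ""),
            "authors": row.get("authors", ""),
            "abstract": row.get("abstract", ""),
            "score": float(hit["score"]),
        })
    return results
```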