stefanoviel committed on
Commit
ce35c00
·
1 Parent(s): 70f287c

using tmp folder

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +31 -17
src/streamlit_app.py CHANGED
@@ -1,20 +1,18 @@
1
  import os
2
-
3
  import streamlit as st
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer, util
6
  import torch
7
- from spellchecker import SpellChecker # Import the spellchecker library
8
  from io import StringIO
9
 
10
  # --- Configuration ---
11
  EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
12
- EMBEDDINGS_FILE = 'paper_embeddings.pt'
13
- DATA_FILE = 'papers_data.pkl'
 
14
 
15
  # --- Data Loading and Preparation ---
16
- # This is the raw data provided by the user.
17
- # In a real application, you might load this from a CSV file.
18
  CSV_FILE = 'papers_with_abstracts_parallel.csv'
19
 
20
  # --- Caching Functions ---
@@ -41,10 +39,14 @@ def create_and_save_embeddings(model, data_df):
41
  # Generate embeddings
42
  corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
43
 
44
- # Save embeddings and dataframe
45
- torch.save(corpus_embeddings, EMBEDDINGS_FILE)
46
- data_df.to_pickle(DATA_FILE)
47
- st.success("Embeddings and data saved successfully!")
 
 
 
 
48
  return corpus_embeddings, data_df
49
 
50
  def load_data_and_embeddings():
@@ -53,13 +55,26 @@ def load_data_and_embeddings():
53
  If files don't exist, it calls the creation function.
54
  """
55
  model = load_embedding_model()
 
 
56
  if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
57
- corpus_embeddings = torch.load(EMBEDDINGS_FILE)
58
- data_df = pd.read_pickle(DATA_FILE)
59
- else:
60
- # Load the raw data from the string
 
 
 
 
 
61
  data_df = pd.read_csv(CSV_FILE)
62
  corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
 
 
 
 
 
 
63
 
64
  return model, corpus_embeddings, data_df
65
 
@@ -91,7 +106,6 @@ def correct_query_spelling(query, spell_checker):
91
 
92
  return " ".join(corrected_words)
93
 
94
-
95
  def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
96
  """
97
  Performs semantic search on the loaded data.
@@ -142,7 +156,7 @@ try:
142
  with col1:
143
  search_query = st.text_input(
144
  "Enter your search query:",
145
- placeholder="e.g., maschine lerning modles for time series"
146
  )
147
  with col2:
148
  top_k_results = st.number_input(
@@ -187,4 +201,4 @@ try:
187
 
188
  except Exception as e:
189
  st.error(f"An error occurred: {e}")
190
- st.info("Please ensure all required libraries are installed (`pip install streamlit pandas sentence-transformers torch pyspellchecker`) and try again.")
 
1
  import os
 
2
  import streamlit as st
3
  import pandas as pd
4
  from sentence_transformers import SentenceTransformer, util
5
  import torch
6
+ from spellchecker import SpellChecker
7
  from io import StringIO
8
 
9
  # --- Configuration ---
10
  EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
11
+ # Use /tmp directory for temporary files in Hugging Face Spaces
12
+ EMBEDDINGS_FILE = '/tmp/paper_embeddings.pt'
13
+ DATA_FILE = '/tmp/papers_data.pkl'
14
 
15
  # --- Data Loading and Preparation ---
 
 
16
  CSV_FILE = 'papers_with_abstracts_parallel.csv'
17
 
18
  # --- Caching Functions ---
 
39
  # Generate embeddings
40
  corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)
41
 
42
+ # Save embeddings and dataframe to /tmp directory
43
+ try:
44
+ torch.save(corpus_embeddings, EMBEDDINGS_FILE)
45
+ data_df.to_pickle(DATA_FILE)
46
+ st.success("Embeddings and data saved successfully!")
47
+ except Exception as e:
48
+ st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
49
+
50
  return corpus_embeddings, data_df
51
 
52
  def load_data_and_embeddings():
 
55
  If files don't exist, it calls the creation function.
56
  """
57
  model = load_embedding_model()
58
+
59
+ # Check if files exist and are readable
60
  if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
61
+ try:
62
+ corpus_embeddings = torch.load(EMBEDDINGS_FILE)
63
+ data_df = pd.read_pickle(DATA_FILE)
64
+ return model, corpus_embeddings, data_df
65
+ except Exception as e:
66
+ st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
67
+
68
+ # Load the raw data from CSV
69
+ try:
70
  data_df = pd.read_csv(CSV_FILE)
71
  corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
72
+ except FileNotFoundError:
73
+ st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
74
+ st.stop()
75
+ except Exception as e:
76
+ st.error(f"Error loading data: {e}")
77
+ st.stop()
78
 
79
  return model, corpus_embeddings, data_df
80
 
 
106
 
107
  return " ".join(corrected_words)
108
 
 
109
  def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
110
  """
111
  Performs semantic search on the loaded data.
 
156
  with col1:
157
  search_query = st.text_input(
158
  "Enter your search query:",
159
+ placeholder="e.g., machine learning models for time series"
160
  )
161
  with col2:
162
  top_k_results = st.number_input(
 
201
 
202
  except Exception as e:
203
  st.error(f"An error occurred: {e}")
204
+ st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")