stefanoviel committed · Commit ce35c00 · 1 Parent(s): 70f287c

using tmp folder

Browse files: src/streamlit_app.py (+31 -17)
src/streamlit_app.py
CHANGED
@@ -1,20 +1,18 @@
 import os
-
 import streamlit as st
 import pandas as pd
 from sentence_transformers import SentenceTransformer, util
 import torch
 from spellchecker import SpellChecker
 from io import StringIO

 # --- Configuration ---
 EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
-
-
+# Use /tmp directory for temporary files in Hugging Face Spaces
+EMBEDDINGS_FILE = '/tmp/paper_embeddings.pt'
+DATA_FILE = '/tmp/papers_data.pkl'

 # --- Data Loading and Preparation ---
-# This is the raw data provided by the user.
-# In a real application, you might load this from a CSV file.
 CSV_FILE = 'papers_with_abstracts_parallel.csv'

 # --- Caching Functions ---
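The point of this hunk is the writable location: on Hugging Face Spaces the repository checkout can be read-only at runtime, while /tmp is writable, hence the new /tmp paths. As a purely illustrative variation (not part of the commit), the cache directory could be derived from the OS instead of hard-coded, so the same file also runs where /tmp does not exist:

# Hypothetical alternative to the hard-coded paths in the hunk above;
# CACHE_DIR is an assumed environment variable, not something this commit adds.
import os
import tempfile

CACHE_DIR = os.environ.get("CACHE_DIR", tempfile.gettempdir())
EMBEDDINGS_FILE = os.path.join(CACHE_DIR, "paper_embeddings.pt")
DATA_FILE = os.path.join(CACHE_DIR, "papers_data.pkl")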
@@ -41,10 +39,14 @@ def create_and_save_embeddings(model, data_df):
     # Generate embeddings
     corpus_embeddings = model.encode(data_df['text_to_embed'].tolist(), convert_to_tensor=True, show_progress_bar=True)

-    # Save embeddings and dataframe
-
-
-
+    # Save embeddings and dataframe to /tmp directory
+    try:
+        torch.save(corpus_embeddings, EMBEDDINGS_FILE)
+        data_df.to_pickle(DATA_FILE)
+        st.success("Embeddings and data saved successfully!")
+    except Exception as e:
+        st.warning(f"Could not save embeddings to disk: {e}. Will regenerate on each session.")
+
     return corpus_embeddings, data_df

 def load_data_and_embeddings():
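Saving is now best-effort: if the write to /tmp fails, the app keeps the in-memory embeddings and only warns that they will be recomputed next session. The load_embedding_model() helper used by the function below is not part of this diff; a typical Streamlit implementation, assumed here, caches the model as a resource so it is instantiated once per process:

# Assumed implementation of load_embedding_model(); not part of this diff.
import streamlit as st
from sentence_transformers import SentenceTransformer

EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

@st.cache_resource
def load_embedding_model():
    # Download the model on first use, then reuse the cached instance
    # across Streamlit reruns.
    return SentenceTransformer(EMBEDDING_MODEL)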
@@ -53,13 +55,26 @@ def load_data_and_embeddings():
     If files don't exist, it calls the creation function.
     """
     model = load_embedding_model()
+
+    # Check if files exist and are readable
     if os.path.exists(EMBEDDINGS_FILE) and os.path.exists(DATA_FILE):
-
-
-
-
+        try:
+            corpus_embeddings = torch.load(EMBEDDINGS_FILE)
+            data_df = pd.read_pickle(DATA_FILE)
+            return model, corpus_embeddings, data_df
+        except Exception as e:
+            st.warning(f"Could not load saved embeddings: {e}. Regenerating...")
+
+    # Load the raw data from CSV
+    try:
         data_df = pd.read_csv(CSV_FILE)
         corpus_embeddings, data_df = create_and_save_embeddings(model, data_df)
+    except FileNotFoundError:
+        st.error(f"CSV file '{CSV_FILE}' not found. Please ensure it's in your repository.")
+        st.stop()
+    except Exception as e:
+        st.error(f"Error loading data: {e}")
+        st.stop()

     return model, corpus_embeddings, data_df

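The loader now degrades step by step: a corrupt or unreadable cache falls through to re-reading the CSV and re-encoding, and only a missing CSV stops the script. Since /tmp inside a Space's container is ephemeral, the cache typically survives reruns within a session but not a restart. The round trip itself is plain torch/pandas; here is a minimal standalone sketch with fabricated data and placeholder paths, not the app's real files:

# Self-contained illustration of the save/load scheme used above; the tensor
# and DataFrame are fabricated stand-ins for the real papers data.
import pandas as pd
import torch

emb = torch.randn(3, 384)                        # fake corpus embeddings
df = pd.DataFrame({"title": ["a", "b", "c"]})    # fake paper metadata

torch.save(emb, "/tmp/paper_embeddings.pt")
df.to_pickle("/tmp/papers_data.pkl")

emb_back = torch.load("/tmp/paper_embeddings.pt")
df_back = pd.read_pickle("/tmp/papers_data.pkl")
assert torch.equal(emb, emb_back) and df.equals(df_back)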
@@ -91,7 +106,6 @@ def correct_query_spelling(query, spell_checker):

     return " ".join(corrected_words)

-
 def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
     """
     Performs semantic search on the loaded data.
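This hunk only drops a stray blank line; the bodies of correct_query_spelling() and semantic_search() are unchanged and not shown. For context, implementations matching these signatures could look like the sketch below, built on pyspellchecker and sentence_transformers.util; this is an assumption, not the author's code:

from sentence_transformers import util

def correct_query_spelling(query, spell_checker):
    # Replace each word with its most likely correction; keep the original
    # word when the checker has no suggestion.
    corrected_words = [spell_checker.correction(w) or w for w in query.split()]
    return " ".join(corrected_words)

def semantic_search(query, model, corpus_embeddings, data_df, top_k=10):
    # Embed the query and rank the corpus by cosine similarity.
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    results = data_df.iloc[[hit['corpus_id'] for hit in hits]].copy()
    results['score'] = [hit['score'] for hit in hits]
    return results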
@@ -142,7 +156,7 @@ try:
     with col1:
         search_query = st.text_input(
             "Enter your search query:",
-            placeholder="e.g.,
+            placeholder="e.g., machine learning models for time series"
         )
     with col2:
         top_k_results = st.number_input(
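The only change here is a more descriptive placeholder in the query box. A rough sketch of how these widgets might feed the search follows; the column ratio, the number_input label, and the run-on-every-keystroke flow are assumptions, and the helpers are the ones defined elsewhere in this file:

import streamlit as st
from spellchecker import SpellChecker

# Assumes load_data_and_embeddings, correct_query_spelling and semantic_search
# from the surrounding app are defined in this script.
model, corpus_embeddings, data_df = load_data_and_embeddings()
spell = SpellChecker()

col1, col2 = st.columns([4, 1])
with col1:
    search_query = st.text_input(
        "Enter your search query:",
        placeholder="e.g., machine learning models for time series",
    )
with col2:
    top_k_results = st.number_input("Results", min_value=1, max_value=50, value=10)

if search_query:
    corrected = correct_query_spelling(search_query, spell)
    results = semantic_search(corrected, model, corpus_embeddings, data_df, top_k=int(top_k_results))
    st.dataframe(results)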
@@ -187,4 +201,4 @@ try:

 except Exception as e:
     st.error(f"An error occurred: {e}")
-    st.info("Please ensure all required libraries are installed
+    st.info("Please ensure all required libraries are installed and the CSV file is present in your repository.")