seanpedrickcase committed
Commit a95ef9f • 1 Parent: 2393537

General code improvements and refinements.

Files changed:
- Dockerfile +0 -2
- app.py +36 -44
- requirements.txt +2 -3
- requirements_gpu.txt +3 -3
- search_funcs/bm25_functions.py +200 -77
- search_funcs/helper_functions.py +35 -6
- search_funcs/semantic_functions.py +108 -396
- search_funcs/spacy_search_funcs.py +6 -1
Dockerfile
CHANGED
@@ -58,7 +58,5 @@ WORKDIR $HOME/app
 
 # Copy the current directory contents into the container at $HOME/app setting the owner to the user
 COPY --chown=user . $HOME/app
-#COPY . $HOME/app
-
 
 CMD ["python", "app.py"]
app.py
CHANGED
@@ -7,7 +7,7 @@ PandasDataFrame = Type[pd.DataFrame]
 
 from search_funcs.bm25_functions import prepare_bm25_input_data, prepare_bm25, bm25_search
 from search_funcs.semantic_ingest_functions import csv_excel_text_to_docs
-from search_funcs.semantic_functions import docs_to_bge_embed_np_array,
+from search_funcs.semantic_functions import docs_to_bge_embed_np_array, bge_semantic_search
 from search_funcs.helper_functions import display_info, initial_data_load, put_columns_in_join_df, get_temp_folder_path, empty_folder, get_connection_params, output_folder
 from search_funcs.spacy_search_funcs import spacy_fuzzy_search
 from search_funcs.aws_functions import load_data_from_aws
@@ -17,39 +17,33 @@ temp_folder_path = get_temp_folder_path()
 empty_folder(temp_folder_path)
 
 ## Gradio app - BM25 search
-
+app = gr.Blocks(theme = gr.themes.Base()) # , css="theme.css"
 
-
-with block:
+with app:
     print("Please don't close this window! Open the below link in the web browser of your choice.")
 
-
-
-
-
-
-
-
-    bm25_search_object_state = gr.State()
-
-    k_val = gr.State(9999)
-    out_passages = gr.State(9999)
-    vec_weight = gr.State(1)
-
-    corpus_state = gr.State()
-    keyword_data_list_state = gr.State([])
-    join_data_state = gr.State(pd.DataFrame())
-    output_file_state = gr.State([])
-
-    orig_keyword_data_state = gr.State(pd.DataFrame())
-    keyword_data_state = gr.State(pd.DataFrame())
-
-    orig_semantic_data_state = gr.State(pd.DataFrame())
-    semantic_data_state = gr.State(pd.DataFrame())
+    # BM25 state objects
+    orig_keyword_data_state = gr.State(pd.DataFrame()) # Original data that is not changed #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State(pd.DataFrame())
+    prepared_keyword_data_state = gr.State(pd.DataFrame()) # Data frame that contains modified data #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State(pd.DataFrame())
+    #tokenised_prepared_keyword_data_state = gr.State([]) # This is data that has been loaded in as tokens #gr.Dataframe(pd.DataFrame(),visible=False) #gr.State()
+    tokenised_prepared_keyword_data_state = gr.State([]) # Data that has been prepared for search (tokenised) #gr.Dataframe(np.array([]), type="array", visible=False) #gr.State([])
+    bm25_search_index_state = gr.State()
+
 
+    # Semantic search state objects
+    orig_semantic_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(),visible=False) # gr.State(pd.DataFrame())
+    semantic_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(),visible=False) # gr.State(pd.DataFrame())
+    semantic_input_document_format = gr.State([])
+    embeddings_state = gr.State(np.array([])) #gr.Dataframe(np.array([]), type="numpy", visible=False) #gr.State(np.array([])) # globals()["embeddings"]
+    semantic_k_val = gr.Number(9999, visible=False)
+
+    # State objects for app in general
     session_hash_state = gr.State("")
     s3_output_folder_state = gr.State("")
+    join_data_state = gr.State(pd.DataFrame()) #gr.Dataframe(pd.DataFrame(), visible=False) #gr.State(pd.DataFrame())
+    output_file_state = gr.Dropdown([], visible=False, allow_custom_value=True) #gr.Dataframe(type="array", visible=False) #gr.State([])
 
+    # Informational state objects
     in_k1_info = gr.State("""k1: Constant used for influencing the term frequency saturation. After saturation is reached, additional
 presence for the term adds a significantly less additional score. According to [1]_, experiments suggest
 that 1.2 < k1 < 2 yields reasonably good results, although the optimal value depends on factors such as
@@ -167,7 +161,7 @@ depends on factors such as the type of documents or queries. Information taken f
     out_aws_data_message = gr.Textbox(label="AWS data load progress")
 
     # Changing search parameters button
-    in_search_param_button.click(fn=prepare_bm25, inputs=[
+    in_search_param_button.click(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message])
 
     # ---
     in_k1_button.click(display_info, inputs=in_k1_info)
@@ -178,43 +172,41 @@ depends on factors such as the type of documents or queries. Information taken f
     ### Loading AWS data ###
     load_aws_keyword_data_button.click(fn=load_data_from_aws, inputs=[in_aws_keyword_file, aws_password_box], outputs=[in_bm25_file, out_aws_data_message])
     load_aws_semantic_data_button.click(fn=load_data_from_aws, inputs=[in_aws_semantic_file, aws_password_box], outputs=[in_semantic_file, out_aws_data_message])
-
 
     ### BM25 SEARCH ###
     # Update dropdowns upon initial file load
-    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column,
+    in_bm25_file.change(initial_data_load, inputs=[in_bm25_file], outputs=[in_bm25_column, search_df_join_column, prepared_keyword_data_state, orig_keyword_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, load_finished_message, current_source], api_name="initial_load")
     in_join_file.change(put_columns_in_join_df, inputs=[in_join_file], outputs=[in_join_column, join_data_state, in_join_message])
 
     # Load in BM25 data
-    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column,
-        then(fn=prepare_bm25, inputs=[
-
+    load_bm25_data_button.click(fn=prepare_bm25_input_data, inputs=[in_bm25_file, in_bm25_column, prepared_keyword_data_state, tokenised_prepared_keyword_data_state, in_clean_data, return_intermediate_files], outputs=[tokenised_prepared_keyword_data_state, load_finished_message, prepared_keyword_data_state, output_file, output_file, in_bm25_column], api_name="load_keyword").\
+        then(fn=prepare_bm25, inputs=[tokenised_prepared_keyword_data_state, in_bm25_file, in_bm25_column, bm25_search_index_state, in_clean_data, return_intermediate_files, in_k1, in_b, in_alpha], outputs=[load_finished_message, output_file, bm25_search_index_state, tokenised_prepared_keyword_data_state], api_name="prepare_keyword") # keyword_data_list_state
 
     # BM25 search functions on click or enter
-    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state,
-    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state,
+    keyword_search_button.click(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file], api_name="keyword_search")
+    keyword_query.submit(fn=bm25_search, inputs=[keyword_query, in_no_search_results, orig_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, in_clean_data, bm25_search_index_state, tokenised_prepared_keyword_data_state, in_join_column, search_df_join_column, in_k1, in_b, in_alpha], outputs=[output_single_text, output_file])
 
     # Fuzzy search functions on click
-    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query,
+    fuzzy_search_button.click(fn=spacy_fuzzy_search, inputs=[keyword_query, tokenised_prepared_keyword_data_state, prepared_keyword_data_state, in_bm25_column, join_data_state, search_df_join_column, in_join_column, no_spelling_mistakes], outputs=[output_single_text, output_file], api_name="fuzzy_search")
 
     ### SEMANTIC SEARCH ###
 
     # Load in a csv/excel file for semantic search
-    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state,
+    in_semantic_file.change(initial_data_load, inputs=[in_semantic_file], outputs=[in_semantic_column, search_df_join_column, semantic_data_state, orig_semantic_data_state, bm25_search_index_state, embeddings_state, tokenised_prepared_keyword_data_state, semantic_load_progress, current_source_semantic])
     load_semantic_data_button.click(
-        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[
-        then(docs_to_bge_embed_np_array, inputs=[
+        csv_excel_text_to_docs, inputs=[semantic_data_state, in_semantic_file, in_semantic_column, in_clean_data, return_intermediate_files], outputs=[semantic_input_document_format, semantic_load_progress, output_file_state]).\
+        then(docs_to_bge_embed_np_array, inputs=[semantic_input_document_format, in_semantic_file, embeddings_state, output_file_state, in_clean_data, return_intermediate_files, embedding_super_compress], outputs=[semantic_load_progress, embeddings_state, semantic_output_file, output_file_state]) # vectorstore_state
 
     # Semantic search query
-    semantic_submit.click(
-    semantic_query.submit(
+    semantic_submit.click(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file], api_name="semantic_search")
+    semantic_query.submit(bge_semantic_search, inputs=[semantic_query, embeddings_state, semantic_input_document_format, semantic_k_val, semantic_min_distance, join_data_state, in_join_column, search_df_join_column], outputs=[semantic_output_single_text, semantic_output_file])
 
-
+    app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state])
 
 # Launch the Gradio app
 if __name__ == "__main__":
-
+    app.queue().launch(show_error=True) # root_path="/data-text-search" # server_name="0.0.0.0",
 
     # Running on local server with https: https://discuss.huggingface.co/t/how-to-run-gradio-with-0-0-0-0-and-https/38003 or https://dev.to/rajshirolkar/fastapi-over-https-for-development-on-windows-2p7d # Need to download OpenSSL and create own keys
-    #
+    # app.queue().launch(ssl_verify=False, share=False, debug=False, server_name="0.0.0.0",server_port=443,
     # ssl_certfile="cert.pem", ssl_keyfile="key.pem") # port 443 for https. Certificates currently not valid
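The load buttons above chain two handlers with .then(), so index preparation only starts once data ingestion has returned, and the two steps communicate through state components instead of globals. A runnable sketch of that chaining, with simplified stand-in handlers (the real ones take many more inputs):

import gradio as gr

# Sketch of the click().then() chaining used for load_bm25_data_button and
# load_semantic_data_button above (handlers here are simplified stand-ins).
with gr.Blocks() as demo:
    token_state = gr.State([])
    text_in = gr.Textbox(label="Text to index")
    status = gr.Textbox(label="Status")
    load_button = gr.Button("Load and index")

    def tokenise(text):
        return text.split()  # written into token_state

    def build_index(tokens):
        return f"Index built over {len(tokens)} tokens"

    # build_index runs only after tokenise completes, reading the value
    # tokenise wrote into token_state.
    load_button.click(fn=tokenise, inputs=[text_in], outputs=[token_state]).\
        then(fn=build_index, inputs=[token_state], outputs=[status])

demo.launch()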
requirements.txt
CHANGED
@@ -1,12 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.
+openpyxl==3.1.3
 torch==2.3.1
-transformers==4.41.2
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
 sentence_transformers==3.0.1
-lxml==5.
+lxml==5.2.2
 boto3==1.34.103
requirements_gpu.txt
CHANGED
@@ -1,11 +1,11 @@
 pandas==2.2.2
 polars==0.20.3
 pyarrow==14.0.2
-openpyxl==3.1.
+openpyxl==3.1.3
 torch==2.3.1 --index-url https://download.pytorch.org/whl/cu121
 spacy
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz
 gradio
-sentence_transformers==
-lxml==5.
+sentence_transformers==3.0.1
+lxml==5.2.2
 boto3==1.34.103
search_funcs/bm25_functions.py
CHANGED
@@ -8,6 +8,7 @@ import time
 import pandas as pd
 from numpy import inf
 import gradio as gr
+from typing import List
 
 from datetime import datetime
 
@@ -165,7 +166,7 @@ class BM25:
         return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]
 
 
-    def get_top_n_with_score(self, query, documents, n=5):
+    def get_top_n_with_score(self, query:str, documents:List[str], n=5):
         """
         Retrieve the top n documents for the query along with their scores.
 
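For reference on what the k1, b and alpha values threaded through the functions below control, here is the textbook Okapi BM25 term weight. This is the standard formula, not necessarily this repo's exact BM25 implementation; in similar implementations alpha acts as an IDF cutoff for very common terms.

import math

# Textbook Okapi BM25 term weight (a reference sketch, not this repo's exact code).
def bm25_term_score(tf, doc_len, avg_doc_len, n_docs, doc_freq, k1=1.5, b=0.75):
    # Rarer terms get a higher inverse document frequency
    idf = math.log(1 + (n_docs - doc_freq + 0.5) / (doc_freq + 0.5))
    # k1 saturates repeated occurrences; b normalises for document length
    saturated_tf = tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))
    return idf * saturated_tf

# With k1=1.5, ten occurrences of a term score roughly twice one occurrence,
# not ten times: this is the "term frequency saturation" described in the
# app's k1 help text.
one = bm25_term_score(tf=1, doc_len=100, avg_doc_len=120, n_docs=10_000, doc_freq=50)
ten = bm25_term_score(tf=10, doc_len=100, avg_doc_len=120, n_docs=10_000, doc_freq=50)
print(ten / one)  # about 2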
@@ -229,15 +230,47 @@ class BM25:
     with open(f"{output_folder}{filename}.pkl", "rb") as fsave:
         return pickle.load(fsave)
 
-
-def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, clean="No", return_intermediate_files = "No", progress=gr.Progress(track_tqdm=True)):
-    #print(in_file)
+def prepare_bm25_input_data(
+    in_file: list,
+    text_column: str,
+    data_state: pd.DataFrame,
+    tokenised_state: list,
+    clean: str = "No",
+    return_intermediate_files: str = "No",
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare BM25 input data by loading, cleaning, and tokenizing the text data.
+
+    Parameters
+    ----------
+    in_file: list
+        List of input files to be processed.
+    text_column: str
+        The name of the text column in the data file to search.
+    data_state: pd.DataFrame
+        The current state of the data.
+    tokenised_state: list
+        The current state of the tokenized data.
+    clean: str, optional
+        Whether to clean the text data (default is "No").
+    return_intermediate_files: str, optional
+        Whether to return intermediate processing files (default is "No").
+    progress: gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing the prepared search text list, a message, the updated data state,
+        the tokenized data, the search index, and a dropdown component for the text column.
+    """
 
     ensure_output_folder_exists(output_folder)
 
     if not in_file:
         print("No input file found. Please load in at least one file.")
-        return None, "No input file found. Please load in at least one file.", data_state, None, None,
+        return None, "No input file found. Please load in at least one file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     progress(0, desc = "Loading in data")
     file_list = [string.name for string in in_file]
@@ -247,25 +280,24 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None,
+        return None, "Please load in at least one csv/Excel/parquet data file.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     if not text_column:
-        return None, "Please enter a column name to search.", data_state, None, None,
+        return None, "Please enter a column name to search.", data_state, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
     data_file_name = data_file_names[0]
 
     df = data_state #read_file(data_file_name)
-    data_file_out_name = get_file_path_end_with_ext(data_file_name)
+    #data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_out_name_no_ext = get_file_path_end(data_file_name)
 
-    ## Load in pre-tokenised
-    tokenised_df = pd.DataFrame()
+    ## Load in pre-tokenised prepared_search_text_list if exists
+    #tokenised_df = pd.DataFrame()
 
-    tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
+    #tokenised_file_names = [string for string in file_list if "tokenised" in string.lower()]
     search_index_file_names = [string for string in file_list if "gz" in string.lower()]
 
-
-
+    # Set all search text to lower case
     df[text_column] = df[text_column].astype(str).str.lower()
 
     if "copy_of_case_note_id" in df.columns:
@@ -273,10 +305,10 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         df.loc[~df["copy_of_case_note_id"].isna(), text_column] = ""
 
     if search_index_file_names:
-
+        prepared_search_text_list = list(df[text_column])
         message = "Tokenisation skipped - loading search index from file."
         print(message)
-        return
+        return prepared_search_text_list, message, df, None, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list())
 
 
     if clean == "Yes":
@@ -285,11 +317,11 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
         print("Starting data clean.")
 
         #df = df.drop_duplicates(text_column)
-
-
+        prepared_text_as_list = list(df[text_column])
+        prepared_text_as_list = initial_clean(prepared_text_as_list)
 
         # Save to file if you have cleaned the data
-        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name,
+        out_file_name, text_column, df = save_prepared_bm25_data(data_file_name, prepared_text_as_list, df, text_column)
 
         clean_toc = time.perf_counter()
         clean_time_out = f"Cleaning the text took {clean_toc - clean_tic:0.1f} seconds."
@@ -297,7 +329,7 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     else:
         # Don't clean or save file to disk
-
+        prepared_text_as_list = list(df[text_column])
         print("No data cleaning performed")
         out_file_name = None
 
@@ -305,24 +337,27 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
 
     progress(0.4, desc = "Tokenising text")
 
+    print("Tokenised state:", tokenised_state)
+
     if tokenised_state:
-
-        corpus = tokenised_df.iloc[:,0].tolist()
+        prepared_search_text_list = tokenised_state.iloc[:,0].tolist()
         print("Tokenised data loaded from file")
-
+
+        #print("prepared_search_text_list is: ", prepared_search_text_list[0:5])
 
     else:
         tokeniser_tic = time.perf_counter()
-
+        prepared_search_text_list = []
         batch_size = 256
-        for doc in tokenizer.pipe(progress.tqdm(
-
+        for doc in tokenizer.pipe(progress.tqdm(prepared_text_as_list, desc = "Tokenising text", unit = "rows"), batch_size=batch_size):
+            prepared_search_text_list.append([token.text for token in doc])
 
         tokeniser_toc = time.perf_counter()
         tokenizer_time_out = f"Tokenising the text took {tokeniser_toc - tokeniser_tic:0.1f} seconds."
         print(tokenizer_time_out)
+        #print("prepared_search_text_list is: ", prepared_search_text_list[0:5])
 
-    if len(
+    if len(prepared_text_as_list) >= 20:
         message = "Data loaded"
     else:
         message = "Data loaded. Warning: dataset may be too short to get consistent search results."
@@ -334,13 +369,29 @@ def prepare_bm25_input_data(in_file, text_column, data_state, tokenised_state, c
     else:
         tokenised_data_file_name = output_folder + data_file_out_name_no_ext + "_tokenised.parquet"
 
-        pd.DataFrame(data={"
+        pd.DataFrame(data={"prepared_search_text_list":prepared_search_text_list}).to_parquet(tokenised_data_file_name)
+
+        return prepared_search_text_list, message, df, out_file_name, tokenised_data_file_name, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list()) # prepared_text_as_list,
 
-
+    return prepared_search_text_list, message, df, out_file_name, None, gr.Dropdown(allow_custom_value=True, value=text_column, choices=data_state.columns.to_list()) # prepared_text_as_list,
 
-
+def save_prepared_bm25_data(in_file_name: str, prepared_text_list: list, in_df: pd.DataFrame, in_bm25_column: str, progress: gr.Progress = gr.Progress(track_tqdm=True)) -> tuple:
+    """
+    Save the prepared BM25 data to a file.
+
+    This function ensures the output folder exists, checks if the length of the prepared text list matches the input dataframe,
+    and saves the prepared data to a file in the specified format. The original column in the input dataframe is dropped to reduce file size.
+
+    Parameters:
+    - in_file_name (str): The name of the input file.
+    - prepared_text_list (list): The list of prepared text.
+    - in_df (pd.DataFrame): The input dataframe.
+    - in_bm25_column (str): The name of the column to be processed.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the file name, new text column name, and the prepared dataframe.
+    """
-
 
     ensure_output_folder_exists(output_folder)
 
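The tokenisation loop added to prepare_bm25_input_data above pipes rows through the tokenizer in batches rather than calling it row by row, which is much faster on large text columns. A self-contained sketch of the same pattern, with spacy.blank("en") standing in for the app's tokenizer:

import spacy

# Batched tokenisation, as in the loop added above. spacy.blank("en") is a
# tokenizer-only pipeline used here as a stand-in.
nlp = spacy.blank("en")
texts = ["first case note", "second case note", "third one"]

prepared_search_text_list = []
for doc in nlp.pipe(texts, batch_size=256):
    prepared_search_text_list.append([token.text for token in doc])

print(prepared_search_text_list)
# [['first', 'case', 'note'], ['second', 'case', 'note'], ['third', 'one']]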
@@ -368,26 +419,54 @@ def save_prepared_bm25_data(in_file_name, prepared_text_list, in_df, in_bm25_col
 
     return file_name, new_text_column, prepared_df
 
-def prepare_bm25(
-
-
-
+def prepare_bm25(
+    prepared_search_text_list: List[str],
+    in_file: List[gr.File],
+    text_column: str,
+    search_index: BM25,
+    clean: str,
+    return_intermediate_files: str,
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Prepare the BM25 search index.
+
+    This function prepares the BM25 search index from the provided text list and input file. It ensures the necessary
+    files and columns are present, processes the data, and optionally saves intermediate files.
+
+    Parameters:
+    - prepared_search_text_list (List[str]): The list of prepared search text.
+    - in_file (List[gr.File]): The list of input files.
+    - text_column (str): The name of the column to search.
+    - search_index (BM25): The BM25 search index.
+    - clean (str): Indicates whether to clean the data.
+    - return_intermediate_files (str): Indicates whether to return intermediate files.
+    - k1 (float, optional): The k1 parameter for BM25. Default is 1.5.
+    - b (float, optional): The b parameter for BM25. Default is 0.75.
+    - alpha (float, optional): The alpha parameter for BM25. Default is -5.
+    - progress (gr.Progress, optional): The progress tracker for the operation.
+
+    Returns:
+    - tuple: A tuple containing the output message, BM25 search index, and other relevant information.
+    """
 
     if not in_file:
         out_message ="No input file found. Please load in at least one file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None
 
-    if not
+    if not prepared_search_text_list:
         out_message = "No data file found. Please load in at least one csv/Excel/Parquet file."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     if not text_column:
         out_message = "Please enter a column name to search."
         print(out_message)
-        return out_message, None
+        return out_message, None, None, None
 
     file_list = [string.name for string in in_file]
 
@@ -397,36 +476,23 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower() and "gz" not in string.lower()]
 
     if not data_file_names:
-        return "Please load in at least one csv/Excel/parquet data file.", None
+        return "Please load in at least one csv/Excel/parquet data file.", None, None, None
 
     data_file_name = data_file_names[0]
     data_file_out_name = get_file_path_end_with_ext(data_file_name)
     data_file_name_no_ext = get_file_path_end(data_file_name)
 
-    # Check if there is a search index file already
-    #index_file_names = [string for string in file_list if "gz" in string.lower()]
-
     progress(0.6, desc = "Preparing search index")
 
-    #if index_file_names:
     if search_index:
-
-
-        #print(index_file_name)
-
-        bm25_load = search_index
-
-
-        #index_file_out_name = get_file_path_end_with_ext(index_file_name)
-        #index_file_name_no_ext = get_file_path_end(index_file_name)
-
+        bm25 = search_index
     else:
-        print("Preparing BM25 corpus")
+        print("Preparing BM25 search corpus")
 
-
+        bm25 = BM25(prepared_search_text_list, k1=k1, b=b, alpha=alpha)
 
-    global bm25
-    bm25 = bm25_load
+    #global bm25
+    #bm25 = bm25_load
 
     if return_intermediate_files == "Yes":
         print("Saving search index file")
@@ -451,7 +517,7 @@ def prepare_bm25(corpus, in_file, text_column, search_index, clean, return_inter
 
     print(message)
 
-    return message, None, bm25
+    return message, None, bm25, prepared_search_text_list
 
 def convert_bm25_query_to_tokens(free_text_query, clean="No"):
     '''
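When return_intermediate_files is "Yes", prepare_bm25 above saves the search index to disk; elsewhere the diff treats .gz files as saved search indexes, and the BM25 class loads pickled objects. A minimal sketch of such a save/load round trip, assuming a gzipped-pickle layout (the path and file naming here are illustrative assumptions, not the repo's exact scheme):

import gzip
import pickle

# Round trip for a pickled, gzip-compressed search index. The path is
# illustrative; the repo derives its own output file names.
def save_index(index, path="output/search_index.pkl.gz"):
    with gzip.open(path, "wb") as f:
        pickle.dump(index, f)

def load_index(path="output/search_index.pkl.gz"):
    with gzip.open(path, "rb") as f:
        return pickle.load(f)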
@@ -474,9 +540,75 @@ def convert_bm25_query_to_tokens(free_text_query, clean="No"):
 
     return out_query
 
-def bm25_search(
+def bm25_search(
+    free_text_query: str,
+    in_no_search_results: int,
+    original_data: pd.DataFrame,
+    searched_data: pd.DataFrame,
+    text_column: str,
+    in_join_file: str,
+    clean: str,
+    bm25: BM25,
+    prepared_search_text_list_state: list,
+    in_join_column: str = "",
+    search_df_join_column: str = "",
+    k1: float = 1.5,
+    b: float = 0.75,
+    alpha: float = -5,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Perform a BM25 search on the provided text data.
+
+    Parameters
+    ----------
+    free_text_query : str
+        The query text to search for.
+    in_no_search_results : int
+        The number of search results to return.
+    original_data : pd.DataFrame
+        The original data containing the text to be searched.
+    searched_data : pd.DataFrame
+        The data that has been prepared for searching.
+    text_column : str
+        The name of the column in the data to search.
+    in_join_file : str
+        The file to join the search results with.
+    clean : str
+        Whether to clean the text data.
+    bm25 : BM25
+        The BM25 object used for searching.
+    prepared_search_text_list_state : list
+        The state of the prepared search text list.
+    in_join_column : str, optional
+        The column to join on in the input file (default is "").
+    search_df_join_column : str, optional
+        The column to join on in the search dataframe (default is "").
+    k1 : float, optional
+        The k1 parameter for BM25 (default is 1.5).
+    b : float, optional
+        The b parameter for BM25 (default is 0.75).
+    alpha : float, optional
+        The alpha parameter for BM25 (default is -5).
+    progress : gr.Progress, optional
+        Progress tracker for the function (default is gr.Progress(track_tqdm=True)).
+
+    Returns
+    -------
+    tuple
+        A tuple containing a message, the search results file name (if any), the BM25 object, and the prepared search text list.
+    """
 
     progress(0, desc = "Conducting keyword search")
+
+    print("in_join_file at start of bm25_search:", in_join_file)
+
+    if not bm25:
+        print("Preparing BM25 search corpus")
+
+        bm25 = BM25(prepared_search_text_list_state, k1=k1, b=b, alpha=alpha)
+
+    # print("bm25:", bm25)
 
     # Prepare query
     if (clean == "Yes") | (text_column.endswith("_cleaned")):
@@ -484,8 +616,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     else:
         token_query = convert_bm25_query_to_tokens(free_text_query, clean="No")
 
-    #print(token_query)
-
     # Perform search
     print("Searching")
 
@@ -504,7 +634,6 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
 
     # Join scores onto searched data
     results_df_out = results_df[['index', 'search_text', 'search_score_abs']].merge(searched_data,left_on="index", right_index=True, how="left", suffixes = ("", "_y")).drop("index_y", axis=1, errors="ignore")
-
 
 
     # Join on data from duplicate case notes
@@ -516,33 +645,27 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
         print("Clean is yes")
         orig_text_column = text_column.replace("_cleaned", "")
 
-        #print(orig_text_column)
-        #print(original_data.columns)
-
        original_data["original_note_id"] = original_data["copy_of_case_note_id"]
         original_data["original_note_id"] = original_data["original_note_id"].combine_first(original_data["note_id"])
 
         results_df_out = results_df_out.merge(original_data[["original_note_id", "note_id", "copy_of_case_note_id", "person_id"]],left_on="note_id", right_on="original_note_id", how="left", suffixes=("_primary", "")) # .drop(orig_text_column, axis = 1)
         results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), "search_text"] = ""
         results_df_out.loc[~results_df_out["copy_of_case_note_id"].isnull(), text_column] = ""
-
-        #results_df_out = pd.concat([results_df_out, original_data[~original_data["copy_of_case_note_id"].isna()][["copy_of_case_note_id", "person_id"]]])
-        # Replace NaN with an empty string
-        # results_df_out.fillna('', inplace=True)
-
-
 
+    print("in_join_file:", in_join_file)
+
     # Join on additional files
     if not in_join_file.empty:
         progress(0.5, desc = "Joining on additional data file")
-        join_df = in_join_file
-
+        #join_df = in_join_file
+        # Prepare join columns as string and remove .0 at end of stringified numbers
+        in_join_file[in_join_column] = in_join_file[in_join_column].astype(str).str.replace("\.0$","", regex=True)
         results_df_out[search_df_join_column] = results_df_out[search_df_join_column].astype(str).str.replace("\.0$","", regex=True)
 
         # Duplicates dropped so as not to expand out dataframe
-
+        in_join_file = in_join_file.drop_duplicates(in_join_column)
 
-        results_df_out = results_df_out.merge(
+        results_df_out = results_df_out.merge(in_join_file,left_on=search_df_join_column, right_on=in_join_column, how="left", suffixes=('','_y'))#.drop(in_join_column, axis=1)
 
     # Reorder results by score, and whether there is text
     results_df_out = results_df_out.sort_values(['search_score_abs', "search_text"], ascending=False)
@@ -559,7 +682,7 @@ def bm25_search(free_text_query, in_no_search_results, original_data, searched_d
     # Highlight found text and save to file
     results_df_out_wb = create_highlighted_excel_wb(results_df_out, free_text_query, "search_text")
     results_df_out_wb.save(results_df_name)
-
+
     results_first_text = results_df_out[text_column].iloc[0]
 
     print("Returning results")
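The join block added to bm25_search above normalises both key columns to strings and strips a trailing ".0" before merging, then deduplicates the join file so the merge cannot fan out result rows. The same steps in isolation:

import pandas as pd

# Join-key normalisation as in bm25_search: IDs that passed through a float
# column ("101.0") are matched against string IDs ("101"), and the join
# file is deduplicated so each result row joins to at most one record.
results = pd.DataFrame({"note_id": [101.0, 102.0], "score": [3.2, 1.1]})
join_df = pd.DataFrame({"note_id": ["101", "101", "103"], "team": ["A", "A", "B"]})

results["note_id"] = results["note_id"].astype(str).str.replace(r"\.0$", "", regex=True)
join_df = join_df.drop_duplicates("note_id")

merged = results.merge(join_df, on="note_id", how="left")
print(merged)  # 101 joins to team A; 102 gets NaN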
search_funcs/helper_functions.py
CHANGED
@@ -9,6 +9,8 @@ import gzip
 import pickle
 import numpy as np
 
+from typing import List
+
 # Openpyxl functions for output
 from openpyxl import Workbook
 from openpyxl.cell.text import InlineFont
@@ -175,15 +177,15 @@ def read_file(filename):
 
     return file
 
-def initial_data_load(in_file):
+def initial_data_load(in_file:List[str]):
     '''
-    When file is loaded, update the column dropdown choices
+    When file is loaded, update the column dropdown choices and relevant state variables
     '''
     new_choices = []
     concat_choices = []
     index_load = None
     embed_load = np.array([])
-    tokenised_load =[]
+    tokenised_load = []
     out_message = ""
     current_source = ""
     df = pd.DataFrame()
@@ -257,7 +259,7 @@ def initial_data_load(in_file):
 
     return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, df, index_load, embed_load, tokenised_load, out_message, current_source
 
-def put_columns_in_join_df(in_file):
+def put_columns_in_join_df(in_file:str):
     '''
     When file is loaded, update the column dropdown choices
     '''
@@ -354,7 +356,20 @@ def highlight_found_text(search_text: str, full_text: str) -> str:
 
     return "".join(pos_tokens), combined_positions
 
-def create_rich_text_cell_from_positions(full_text, combined_positions):
+def create_rich_text_cell_from_positions(full_text: str, combined_positions: list[tuple[int, int]]) -> CellRichText:
+    """
+    Create a rich text cell with highlighted positions.
+
+    This function takes the full text and a list of combined positions, and creates a rich text cell
+    with the specified positions highlighted in red.
+
+    Parameters:
+    full_text (str): The full text to be processed.
+    combined_positions (list[tuple[int, int]]): A list of tuples representing the start and end positions to be highlighted.
+
+    Returns:
+    CellRichText: The created rich text cell with highlighted positions.
+    """
     # Construct pos_tokens
     red = InlineFont(color='00FF0000')
     rich_text_cell = CellRichText()
@@ -369,7 +384,21 @@ def create_rich_text_cell_from_positions(full_text, combined_positions):
 
     return rich_text_cell
 
-def create_highlighted_excel_wb(df, search_text, column_to_highlight):
+def create_highlighted_excel_wb(df: pd.DataFrame, search_text: str, column_to_highlight: str) -> Workbook:
+    """
+    Create a new Excel workbook with highlighted search text.
+
+    This function takes a DataFrame, a search text, and a column name to highlight. It creates a new Excel workbook,
+    highlights the occurrences of the search text in the specified column, and returns the workbook.
+
+    Parameters:
+    df (pd.DataFrame): The DataFrame containing the data to be written to the Excel workbook.
+    search_text (str): The text to search for and highlight in the specified column.
+    column_to_highlight (str): The name of the column in which to highlight the search text.
+
+    Returns:
+    Workbook: The created Excel workbook with highlighted search text.
+    """
 
     # Create a new Excel workbook
     wb = Workbook()
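The docstrings added above describe the Excel highlighting pipeline: matched character spans are rebuilt as openpyxl rich text so that only the hits render in red. A compact sketch of that technique with a hard-coded span (the positions would normally come from highlight_found_text):

from openpyxl import Workbook
from openpyxl.cell.text import InlineFont
from openpyxl.cell.rich_text import CellRichText, TextBlock

# Plain text segments and red TextBlocks are interleaved into one
# CellRichText, as in create_rich_text_cell_from_positions. The span
# below is hard-coded for illustration rather than found by a search.
red = InlineFont(color='00FF0000')
full_text = "patient reports severe headache since Monday"
start, end = 23, 31  # the span "headache"

cell_value = CellRichText(
    full_text[:start],
    TextBlock(red, full_text[start:end]),
    full_text[end:],
)

wb = Workbook()
ws = wb.active
ws["A1"] = cell_value
wb.save("highlighted.xlsx")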
search_funcs/semantic_functions.py
CHANGED
@@ -5,11 +5,10 @@ from typing import Type
 import gradio as gr
 import numpy as np
 from datetime import datetime
-
-from
-#import torch
-from torch import cuda, backends#, tensor, mm, utils
 from sentence_transformers import SentenceTransformer
 
 today_rev = datetime.now().strftime("%Y%m%d")
 
@@ -25,22 +24,6 @@ else:
 
 print("Device used is: ", torch_device)
 
-from search_funcs.helper_functions import create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
-
-PandasDataFrame = Type[pd.DataFrame]
-
-# Load embeddings - Jina - deprecated
-# Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
-# Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
-# embeddings_name = "jinaai/jina-embeddings-v2-small-en"
-# local_embeddings_location = "model/jina/"
-# revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
-
-# try:
-#     embeddings_model = AutoModel.from_pretrained(local_embeddings_location, revision = revision_choice, trust_remote_code=True,local_files_only=True, device_map="auto")
-# except:
-#     embeddings_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True, device_map="auto")
-
 # Load embeddings
 embeddings_name = "BAAI/bge-small-en-v1.5"
 
@@ -65,32 +48,53 @@ else:
     embeddings_model = SentenceTransformer(embeddings_name)
     print("Could not find local model installation. Downloading from Huggingface")
 
-
-
-
-
 
     ensure_output_folder_exists(output_folder)
 
     if not in_file:
         out_message = "No input file found. Please load in at least one file."
         print(out_message)
-        return out_message, None, None, output_file_state
-
 
     progress(0.6, desc = "Loading/creating embeddings")
 
     print(f"> Total split documents: {len(docs_out)}")
 
-    #print(docs_out)
-
     page_contents = [doc.page_content for doc in docs_out]
 
     ## Load in pre-embedded file if exists
     file_list = [string.name for string in in_file]
 
-    #print(file_list)
-
     embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
     data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
     data_file_name = data_file_names[0]
@@ -98,22 +102,12 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_
 
     out_message = "Document processing complete. Ready to search."
 
-    # print("embeddings loaded: ", embeddings_out)
 
     if embeddings_state.size == 0:
         tic = time.perf_counter()
         print("Starting to embed documents.")
-        #embeddings_list = []
-        #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
-        #    embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
-
-
-
-        #embeddings_out = calc_bge_norm_embeddings(page_contents, embeddings_model, tokenizer)
 
         embeddings_out = embeddings_model.encode(sentences=page_contents, show_progress_bar = True, batch_size = 32, normalize_embeddings=True) # For BGE
-        #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
-        #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
 
         toc = time.perf_counter()
         time_out = f"The embedding took {toc - tic:0.1f} seconds"
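The surviving encode call above is the heart of the new embedding path: every document is embedded once, with normalisation turned on so that a later dot product equals cosine similarity. A stand-alone version of the same call (this downloads the BGE model on first run):

from sentence_transformers import SentenceTransformer

# The BGE encode step as used above. normalize_embeddings=True gives unit
# vectors, so query @ embeddings.T later yields cosine similarities directly.
embeddings_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

page_contents = ["social care assessment notes", "housing support referral"]
embeddings_out = embeddings_model.encode(
    sentences=page_contents,
    batch_size=32,
    normalize_embeddings=True,
    show_progress_bar=True,
)
print(embeddings_out.shape)  # (2, 384) for bge-small-en-v1.5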
@@ -147,31 +141,43 @@ def docs_to_bge_embed_np_array(docs_out, in_file, embeddings_state, output_file_
|
|
147 |
|
148 |
return out_message, embeddings_out, output_file_state, output_file_state
|
149 |
|
150 |
-
def process_data_from_scores_df(
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
docs_scores = df_docs["distances"] #.astype(float)
|
169 |
|
170 |
# Only keep sources that are sufficiently relevant (i.e. similarity search score below threshold below)
|
171 |
score_more_limit = df_docs.loc[docs_scores > vec_score_cut_off, :]
|
172 |
-
#docs_keep = create_docs_keep_from_df(score_more_limit) #list(compress(docs, score_more_limit))
|
173 |
-
|
174 |
-
#print(docs_keep)
|
175 |
|
176 |
if score_more_limit.empty:
|
177 |
return pd.DataFrame()
|
@@ -179,26 +185,17 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
|
|
179 |
# Only keep sources that are at least 100 characters long
|
180 |
docs_len = score_more_limit["documents"].str.len() >= 100
|
181 |
|
182 |
-
#print(docs_len)
|
183 |
-
|
184 |
length_more_limit = score_more_limit.loc[docs_len == True, :] #pd.Series(docs_len) >= 100
|
185 |
-
#docs_keep = create_docs_keep_from_df(length_more_limit) #list(compress(docs_keep, length_more_limit))
|
186 |
-
|
187 |
-
#print(length_more_limit)
|
188 |
|
189 |
if length_more_limit.empty:
|
190 |
return pd.DataFrame()
|
191 |
|
192 |
length_more_limit['ids'] = length_more_limit['ids'].astype(int)
|
193 |
|
194 |
-
#length_more_limit.to_csv("length_more_limit.csv", index = None)
|
195 |
|
196 |
# Explode the 'metadatas' dictionary into separate columns
|
197 |
df_metadata_expanded = length_more_limit['metadatas'].apply(pd.Series)
|
198 |
|
199 |
-
#print(length_more_limit)
|
200 |
-
#print(df_metadata_expanded)
|
201 |
-
|
202 |
# Concatenate the original DataFrame with the expanded metadata DataFrame
|
203 |
results_df_out = pd.concat([length_more_limit.drop('metadatas', axis=1), df_metadata_expanded], axis=1)
|
204 |
|
@@ -208,9 +205,6 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
|
|
208 |
results_df_out['distances'] = round(results_df_out['distances'].astype(float), 3)
|
209 |
|
210 |
|
211 |
-
# Join back to original df
|
212 |
-
# results_df_out = orig_df.merge(length_more_limit[['ids', 'distances']], left_index = True, right_on = "ids", how="inner").sort_values("distances")
|
213 |
-
|
214 |
# Join on additional files
|
215 |
if not in_join_file.empty:
|
216 |
progress(0.5, desc = "Joining on additional data file")
|
@@ -227,68 +221,73 @@ def process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_c
    return results_df_out

-def bge_simple_retrieval(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
-                           vec_score_cut_off:float, vec_weight:float, in_join_file, in_join_column = None, search_df_join_column = None, device = torch_device, embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)):
+def bge_semantic_search(
+    query_str: str,
+    embeddings: np.ndarray,
+    documents: list,
+    k_val: int,
+    vec_score_cut_off: float,
+    in_join_file: pd.DataFrame,
+    in_join_column: str = None,
+    search_df_join_column: str = None,
+    device: str = torch_device,
+    embeddings_model: SentenceTransformer = embeddings_model,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> pd.DataFrame:
+    """
+    Perform a semantic search using the BGE model.
+
+    Parameters:
+    - query_str (str): The query string to search for.
+    - embeddings (np.ndarray): The embeddings to search within.
+    - documents (list): The list of documents to search.
+    - k_val (int): The number of top results to return.
+    - vec_score_cut_off (float): The score cutoff for filtering results.
+    - in_join_file (pd.DataFrame): The DataFrame to join with the search results.
+    - in_join_column (str, optional): The column name in the join DataFrame to join on. Default is None.
+    - search_df_join_column (str, optional): The column name in the search DataFrame to join on. Default is None.
+    - device (str, optional): The device to run the model on. Default is torch_device.
+    - embeddings_model (SentenceTransformer, optional): The embeddings model to use. Default is embeddings_model.
+    - progress (gr.Progress, optional): Progress tracker for the function. Default is gr.Progress(track_tqdm=True).
+
+    Returns:
+    - pd.DataFrame: The DataFrame containing the search results.
+    """

-    # print("vectorstore loaded: ", vectorstore)
    progress(0, desc = "Conducting semantic search")

    ensure_output_folder_exists(output_folder)

    print("Searching")

-    # Convert it to a PyTorch tensor and transfer to GPU
-    #vectorstore_tensor = tensor(vectorstore).to(device)
-
    # Load the sentence transformer model and move it to GPU
-    embeddings = embeddings.to(device)
+    embeddings_model = embeddings_model.to(device)

    # Encode the query using the sentence transformer and convert to a PyTorch tensor
-    query = embeddings.encode(query_str)
-
-    # query = calc_bge_norm_embeddings(query_str, embeddings_model=embeddings_model, tokenizer=tokenizer)
-
-    #query_tensor = tensor(query).to(device)
-
-    # if query_tensor.dim() == 1:
-    #     query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
+    query = embeddings_model.encode(query_str, normalize_embeddings=True)

    # Sentence transformers method, not used:
-    cosine_similarities = query @ vectorstore.T
-    #cosine_similarities = util.cos_sim(query_tensor, vectorstore_tensor)[0]
-    #top_results = torch.topk(cos_scores, k=top_k)
-
-
-    # Normalize the query tensor and vectorstore tensor
-    #query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
-    #vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
-
-    # Calculate cosine similarities (batch processing)
-    #cosine_similarities = mm(query_norm, vectorstore_norm.T)
-    #cosine_similarities = mm(query_tensor, vectorstore_tensor.T)
+    cosine_similarities = query @ embeddings.T

    # Flatten the tensor to a 1D array
    cosine_similarities = cosine_similarities.flatten()

-    # Convert to a NumPy array if it's still a PyTorch tensor
-    #cosine_similarities = cosine_similarities.cpu().numpy()
-
    # Create a Pandas Series
    cosine_similarities_series = pd.Series(cosine_similarities)

-    # Pull out relevent info from docs
-    page_contents = [doc.page_content for doc in docs]
-    page_meta = [doc.metadata for doc in docs]
+    # Pull out relevant info from documents
+    page_contents = [doc.page_content for doc in documents]
+    page_meta = [doc.metadata for doc in documents]
    ids_range = range(0,len(page_contents))
    ids = [str(element) for element in ids_range]

-    df_docs = pd.DataFrame(data={"ids": ids,
+    df_documents = pd.DataFrame(data={"ids": ids,
                                 "documents": page_contents,
                                 "metadatas":page_meta,
                                 "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]


-    results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
+    results_df_out = process_data_from_scores_df(df_documents, in_join_file, vec_score_cut_off, in_join_column, search_df_join_column)

    print("Search complete")
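The `sort_values(...).iloc[0:k_val, :]` pattern in `bge_semantic_search` is a simple top-k selection. In miniature:

import pandas as pd

distances = pd.Series([0.12, 0.87, 0.45, 0.66])  # made-up similarity scores
k_val = 2
top_k = distances.sort_values(ascending=False).iloc[0:k_val]
print(top_k.index.tolist())  # [1, 3] - positions of the two best matches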
@@ -312,291 +311,4 @@ def bge_simple_retrieval(query_str:str, vectorstore, docs, orig_df_col:str, k_va

    print("Returning results")

-    return results_first_text, results_df_name
-
-
-def docs_to_jina_embed_np_array_deprecated(docs_out, in_file, embeddings_state, return_intermediate_files = "No", embeddings_super_compress = "No", embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)):
-    '''
-    Takes a Langchain document class and saves it into a Chroma sqlite file.
-    '''
-    if not in_file:
-        out_message = "No input file found. Please load in at least one file."
-        print(out_message)
-        return out_message, None, None
-
-    progress(0.6, desc = "Loading/creating embeddings")
-
-    print(f"> Total split documents: {len(docs_out)}")
-
-    #print(docs_out)
-
-    page_contents = [doc.page_content for doc in docs_out]
-
-    ## Load in pre-embedded file if exists
-    file_list = [string.name for string in in_file]
-
-    #print(file_list)
-
-    embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
-    data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
-    data_file_name = data_file_names[0]
-    data_file_name_no_ext = get_file_path_end(data_file_name)
-
-    out_message = "Document processing complete. Ready to search."
-
-    # print("embeddings loaded: ", embeddings_out)
-
-    if embeddings_state.size == 0:
-        tic = time.perf_counter()
-        print("Starting to embed documents.")
-        #embeddings_list = []
-        #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
-        #    embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
-
-        embeddings_out = embeddings.encode(sentences=page_contents, max_length=1024, show_progress_bar = True, batch_size = 32) # For Jina embeddings
-        #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
-        #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
-
-        toc = time.perf_counter()
-        time_out = f"The embedding took {toc - tic:0.1f} seconds"
-        print(time_out)
-
-        # If you want to save your files for next time
-        if return_intermediate_files == "Yes":
-            progress(0.9, desc = "Saving embeddings to file")
-            if embeddings_super_compress == "No":
-                semantic_search_file_name = data_file_name_no_ext + '_' + 'embeddings.npz'
-                np.savez_compressed(semantic_search_file_name, embeddings_out)
-            else:
-                semantic_search_file_name = data_file_name_no_ext + '_' + 'embedding_compress.npz'
-                embeddings_out_round = np.round(embeddings_out, 3)
-                embeddings_out_round *= 100 # Rounding not currently used
-                np.savez_compressed(semantic_search_file_name, embeddings_out_round)
-
-            return out_message, embeddings_out, semantic_search_file_name
-
-        return out_message, embeddings_out, None
-    else:
-        # Just return existing embeddings if already exist
-        embeddings_out = embeddings_state
-
-    print(out_message)
-
-    return out_message, embeddings_out, None#, None
-
-def jina_simple_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
-                           vec_score_cut_off:float, vec_weight:float, in_join_file, in_join_column = None, search_df_join_column = None, device = torch_device, embeddings = embeddings_model, progress=gr.Progress(track_tqdm=True)): # ,vectorstore, embeddings
-
-    # print("vectorstore loaded: ", vectorstore)
-    progress(0, desc = "Conducting semantic search")
-
-    print("Searching")
-
-    # Convert it to a PyTorch tensor and transfer to GPU
-    vectorstore_tensor = tensor(vectorstore).to(device)
-
-    # Load the sentence transformer model and move it to GPU
-    embeddings = embeddings.to(device)
-
-    # Encode the query using the sentence transformer and convert to a PyTorch tensor
-    query = embeddings.encode(query_str)
-    query_tensor = tensor(query).to(device)
-
-    if query_tensor.dim() == 1:
-        query_tensor = query_tensor.unsqueeze(0) # Reshape to 2D with one row
-
-    # Normalize the query tensor and vectorstore tensor
-    query_norm = query_tensor / query_tensor.norm(dim=1, keepdim=True)
-    vectorstore_norm = vectorstore_tensor / vectorstore_tensor.norm(dim=1, keepdim=True)
-
-    # Calculate cosine similarities (batch processing)
-    cosine_similarities = mm(query_norm, vectorstore_norm.T)
-
-    # Flatten the tensor to a 1D array
-    cosine_similarities = cosine_similarities.flatten()
-
-    # Convert to a NumPy array if it's still a PyTorch tensor
-    cosine_similarities = cosine_similarities.cpu().numpy()
-
-    # Create a Pandas Series
-    cosine_similarities_series = pd.Series(cosine_similarities)
-
-    # Pull out relevent info from docs
-    page_contents = [doc.page_content for doc in docs]
-    page_meta = [doc.metadata for doc in docs]
-    ids_range = range(0,len(page_contents))
-    ids = [str(element) for element in ids_range]
-
-    df_docs = pd.DataFrame(data={"ids": ids,
-                                 "documents": page_contents,
-                                 "metadatas":page_meta,
-                                 "distances":cosine_similarities_series}).sort_values("distances", ascending=False).iloc[0:k_val,:]
-
-    results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
-
-    print("Search complete")
-
-    # If nothing found, return error message
-    if results_df_out.empty:
-        return 'No result found!', None
-
-    query_str_file = query_str.replace(" ", "_")
-
-    results_df_name = "semantic_search_result_" + today_rev + "_" + query_str_file + ".xlsx"
-
-    print("Saving search output to file")
-    progress(0.7, desc = "Saving search output to file")
-
-    results_df_out.to_excel(results_df_name, index= None)
-    results_first_text = results_df_out.iloc[0, 1]
-
-    print("Returning results")
-
-    return results_first_text, results_df_name
-
-# Deprecated Chroma functions - kept just in case needed in future.
-# Chroma support is currently deprecated
-# Import Chroma and instantiate a client. The default Chroma client is ephemeral, meaning it will not save to disk.
-#import chromadb
-#from chromadb.config import Settings
-#from typing_extensions import Protocol
-#from chromadb import Documents, EmbeddingFunction, Embeddings
-
-# Remove Chroma database file. If it exists as it can cause issues
-#chromadb_file = "chroma.sqlite3"
-
-#if os.path.isfile(chromadb_file):
-#    os.remove(chromadb_file)
-
-
-def docs_to_chroma_save_deprecated(docs_out, embeddings = embeddings_model, progress=gr.Progress()):
-    '''
-    Takes a Langchain document class and saves it into a Chroma sqlite file. Not currently used.
-    '''
-
-    print(f"> Total split documents: {len(docs_out)}")
-
-    #print(docs_out)
-
-    page_contents = [doc.page_content for doc in docs_out]
-    page_meta = [doc.metadata for doc in docs_out]
-    ids_range = range(0,len(page_contents))
-    ids = [str(element) for element in ids_range]
-
-    tic = time.perf_counter()
-    #embeddings_list = []
-    #for page in progress.tqdm(page_contents, desc = "Preparing search index", unit = "rows"):
-    #    embeddings_list.append(embeddings.encode(sentences=page, max_length=1024).tolist())
-
-    embeddings_list = embeddings.encode(sentences=page_contents, max_length=256, show_progress_bar = True, batch_size = 32).tolist() # For Jina embeddings
-    #embeddings_list = embeddings.encode(sentences=page_contents, normalize_embeddings=True).tolist() # For BGE embeddings
-    #embeddings_list = embeddings.encode(sentences=page_contents).tolist() # For minilm
-
-    toc = time.perf_counter()
-    time_out = f"The embedding took {toc - tic:0.1f} seconds"
-
-    #pd.Series(embeddings_list).to_csv("embeddings_out.csv")
-
-    # Jina tiny
-    # This takes about 300 seconds for 240,000 records = 800 / second, 1024 max length
-    # For 50k records:
-    # 61 seconds at 1024 max length
-    # 55 seconds at 512 max length
-    # 43 seconds at 256 max length
-    # 31 seconds at 128 max length
-
-    # The embedding took 1372.5 seconds at 256 max length for 655,020 case notes
-
-    # BGE small
-    # 96 seconds for 50k records at 512 length
-
-    # all-MiniLM-L6-v2
-    # 42.5 seconds at (256?) max length
-
-    # paraphrase-MiniLM-L3-v2
-    # 22 seconds for 128 max length
-
-    print(time_out)
-
-    chroma_tic = time.perf_counter()
-
-    # Create a new Chroma collection to store the documents and metadata. We don't need to specify an embedding fuction, and the default will be used.
-    client = chromadb.PersistentClient(path="./last_year", settings=Settings(
-        anonymized_telemetry=False))
-
-    try:
-        print("Deleting existing collection.")
-        #collection = client.get_collection(name="my_collection")
-        client.delete_collection(name="my_collection")
-        print("Creating new collection.")
-        collection = client.create_collection(name="my_collection")
-    except:
-        print("Creating new collection.")
-        collection = client.create_collection(name="my_collection")
-
-    # Match batch size is about 40,000, so add that amount in a loop
-    def create_batch_ranges(in_list, batch_size=40000):
-        total_rows = len(in_list)
-        ranges = []
-
-        for start in range(0, total_rows, batch_size):
-            end = min(start + batch_size, total_rows)
-            ranges.append(range(start, end))
-
-        return ranges
-
-    batch_ranges = create_batch_ranges(embeddings_list)
-    print(batch_ranges)
-
-    for row_range in progress.tqdm(batch_ranges, desc = "Creating vector database", unit = "batches of 40,000 rows"):
-
-        collection.add(
-            documents = page_contents[row_range[0]:row_range[-1]],
-            embeddings = embeddings_list[row_range[0]:row_range[-1]],
-            metadatas = page_meta[row_range[0]:row_range[-1]],
-            ids = ids[row_range[0]:row_range[-1]])
-        #print("Here")
-
-    # print(collection.count())
-
-    #chatf.vectorstore = vectorstore_func
-
-    chroma_toc = time.perf_counter()
-
-    chroma_time_out = f"Loading to Chroma db took {chroma_toc - chroma_tic:0.1f} seconds"
-    print(chroma_time_out)
-
-    out_message = "Document processing complete"
-
-    return out_message, collection
-
-def chroma_retrieval_deprecated(query_str:str, vectorstore, docs, orig_df_col:str, k_val:int, out_passages:int,
-                           vec_score_cut_off:float, vec_weight:float, in_join_file = None, in_join_column = None, search_df_join_column = None, embeddings = embeddings_model): # ,vectorstore, embeddings
-
-    query = embeddings.encode(query_str).tolist()
-
-    docs = vectorstore.query(
-        query_embeddings=query,
-        n_results= k_val # No practical limit on number of responses returned
-        #where={"metadata_field": "is_equal_to_this"},
-        #where_document={"$contains":"search_string"}
-    )
-
-    df_docs = pd.DataFrame(data={'ids': docs['ids'][0],
-                                 'documents': docs['documents'][0],
-                                 'metadatas':docs['metadatas'][0],
-                                 'distances':docs['distances'][0]#,
-                                 #'embeddings': docs['embeddings']
-                                 })
-
-    results_df_out = process_data_from_scores_df(df_docs, in_join_file, out_passages, vec_score_cut_off, vec_weight, orig_df_col, in_join_column, search_df_join_column)
-
-    results_df_name = output_folder + "semantic_search_result.csv"
-    results_df_out.to_csv(results_df_name, index= None)
-    results_first_text = results_df_out[orig_df_col].iloc[0]
-
-    return results_first_text, results_df_name
+    return results_first_text, results_df_name
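Both the deleted Jina path and the surviving BGE path cache embeddings with `np.savez_compressed`. Reading such a file back looks roughly like this (file name illustrative):

import numpy as np

embeddings_out = np.random.rand(10, 384).astype(np.float32)
np.savez_compressed("example_embeddings.npz", embeddings_out)

loaded = np.load("example_embeddings.npz")
embeddings_back = loaded[loaded.files[0]]  # savez stores unnamed arrays under keys like 'arr_0'
assert embeddings_back.shape == embeddings_out.shape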
@@
import gradio as gr
import numpy as np
from datetime import datetime
+from search_funcs.helper_functions import get_file_path_end, create_highlighted_excel_wb, ensure_output_folder_exists, output_folder
+from torch import cuda, backends
from sentence_transformers import SentenceTransformer
+PandasDataFrame = Type[pd.DataFrame]

today_rev = datetime.now().strftime("%Y%m%d")

@@
print("Device used is: ", torch_device)

# Load embeddings
embeddings_name = "BAAI/bge-small-en-v1.5"

@@
embeddings_model = SentenceTransformer(embeddings_name)
print("Could not find local model installation. Downloading from Huggingface")

+def docs_to_bge_embed_np_array(
+    docs_out: list,
+    in_file: list,
+    embeddings_state: np.ndarray,
+    output_file_state: str,
+    clean: str,
+    return_intermediate_files: str = "No",
+    embeddings_super_compress: str = "No",
+    embeddings_model: SentenceTransformer = embeddings_model,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
+) -> tuple:
+    """
+    Process documents to create BGE embeddings and save them as a numpy array.
+
+    Parameters:
+    - docs_out (list): List of documents to be embedded.
+    - in_file (list): List of input files.
+    - embeddings_state (np.ndarray): Current state of embeddings.
+    - output_file_state (str): State of the output file.
+    - clean (str): Indicates if the data should be cleaned.
+    - return_intermediate_files (str, optional): Whether to return intermediate files. Default is "No".
+    - embeddings_super_compress (str, optional): Whether to super compress the embeddings. Default is "No".
+    - embeddings_model (SentenceTransformer, optional): The embeddings model to use. Default is embeddings_model.
+    - progress (gr.Progress, optional): Progress tracker for the function. Default is gr.Progress(track_tqdm=True).
+
+    Returns:
+    - tuple: A tuple containing the output message, embeddings, and output file state.
+    """

    ensure_output_folder_exists(output_folder)

    if not in_file:
        out_message = "No input file found. Please load in at least one file."
        print(out_message)
+        return out_message, None, None, output_file_state

    progress(0.6, desc = "Loading/creating embeddings")

    print(f"> Total split documents: {len(docs_out)}")

    page_contents = [doc.page_content for doc in docs_out]

    ## Load in pre-embedded file if exists
    file_list = [string.name for string in in_file]

    embeddings_file_names = [string for string in file_list if "embedding" in string.lower()]
    data_file_names = [string for string in file_list if "tokenised" not in string.lower() and "npz" not in string.lower()]# and "gz" not in string.lower()]
    data_file_name = data_file_names[0]

    out_message = "Document processing complete. Ready to search."

    if embeddings_state.size == 0:
        tic = time.perf_counter()
        print("Starting to embed documents.")

+        embeddings_out = embeddings_model.encode(sentences=page_contents, show_progress_bar = True, batch_size = 32, normalize_embeddings=True) # For BGE

        toc = time.perf_counter()
        time_out = f"The embedding took {toc - tic:0.1f} seconds"
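The `embeddings_super_compress` option appears to keep the rounding trick from the deleted Jina-era code (whose own comment notes the rounding is not currently used): cut precision to three decimal places and rescale before compressing. Sketched with random data:

import numpy as np

embeddings_out = np.random.rand(5, 384).astype(np.float32)

embeddings_out_round = np.round(embeddings_out, 3)  # drop precision the search does not need
embeddings_out_round *= 100                         # rescale before compressing
np.savez_compressed("embedding_compress_example.npz", embeddings_out_round)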
search_funcs/spacy_search_funcs.py
CHANGED
@@ -27,9 +27,14 @@ except:
nlp = spacy.load("en_core_web_sm")
print("Successfully imported spaCy model")

-def spacy_fuzzy_search(string_query:str,
+def spacy_fuzzy_search(string_query:str, tokenised_data: List[List[str]], original_data: PandasDataFrame, text_column:str, in_join_file: PandasDataFrame, search_df_join_column:str, in_join_column:str, no_spelling_mistakes:int = 1, progress=gr.Progress(track_tqdm=True)):
    ''' Conduct fuzzy match on a list of data.'''

+    #print("df_list:", df_list)
+
+    # Convert tokenised data back into a list of strings
+    df_list = list(map(" ".join, tokenised_data))
+
    if len(df_list) > 10000:
        out_message = "Your data has more than 10,000 rows and will take more than three minutes to do a fuzzy search. Please try keyword or semantic search for data of this size."
        return out_message, None
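The new `tokenised_data` parameter arrives as lists of tokens, so the function first stitches each row back into a single string before the row-count check. In isolation:

tokenised_data = [["first", "row", "of", "text"], ["second", "row"]]
df_list = list(map(" ".join, tokenised_data))
print(df_list)  # ['first row of text', 'second row']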