Commit 36623c8
Parent(s): 44fcc42
Add cache

- benchmark.py +5 -4
- main.py +25 -14
- requirements old.txt +91 -0
- requirements.txt +30 -2
- retrieval_pipeline/cache.py +94 -0
- retrieval_pipeline/hybrid_search.py +15 -4
benchmark.py
CHANGED
@@ -7,6 +7,7 @@ TOP_N = 3
 
 def get_benchmark_result(path, retriever):
     df = pd.read_csv(path)
+
     retrieval_result = []
     query_result = [[] for i in range(TOP_N)]
     retrieval_latency = []
@@ -21,13 +22,13 @@ def get_benchmark_result(path, retriever):
         t0 = time.time()
         results = retriever.get_relevant_documents(query)
         t = time.time() - t0
-        retrieval_latency.append(t)
+        retrieval_latency.append(str(t))
 
         result_content = [result.page_content for result in results]
         # results_content = get_relevant_documents(query, retriever, top_k=5)
 
         for i, text in enumerate(result_content):
-            query_result[i]
+            query_result[i].append(text)
 
         if target in result_content:
             retrieval_result.append("Success")
@@ -37,10 +38,10 @@ def get_benchmark_result(path, retriever):
         # break
 
     df["retrieval_result"] = retrieval_result
-    df["retrieval_latency"] = retrieval_latency
+    df["retrieval_latency"] = retrieval_latency
     for i in range(TOP_N):
         df[f'q{i+1}'] = query_result[i]
-
+    df.to_csv('benchmark_result.csv')
     print(df['retrieval_result'].value_counts())
     print(df['retrieval_result'].value_counts()/ len(df))
 
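In the updated get_benchmark_result, a query counts as a success when the target chunk appears anywhere among the retrieved page contents, and the final value_counts() / len(df) lines print the success rate. A minimal standalone sketch of that metric (hit_rate is a hypothetical helper, reading the retrieval_result column written by the diff):

import pandas as pd

def hit_rate(df: pd.DataFrame) -> float:
    # Fraction of benchmark rows whose retrieval_result is "Success".
    return (df["retrieval_result"] == "Success").mean()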
main.py
CHANGED
@@ -1,10 +1,13 @@
 from dotenv import load_dotenv
 import json
-import os
+import os, time
 import uuid
 
 from retrieval_pipeline import get_retriever, get_compression_retriever
 import benchmark
+from retrieval_pipeline.hybrid_search import store
+
+from retrieval_pipeline.cache import SemanticCache
 
 load_dotenv()
 ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
@@ -16,18 +19,26 @@ print(ELASTICSEARCH_URL)
 if __name__ == "__main__":
     retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
     compression_retriever = get_compression_retriever(retriever)
+
+    semantic_cache_retriever = SemanticCache(compression_retriever)
+
     retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
     print(retrieved_chunks)
-
-    #
-
-
-
-
-
-
-
-
-
-
+
+    # benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
+
+    for i in range(100):
+        query = input("query: ")
+        t0 = time.time()
+        # retrieved_chunks = compression_retriever.get_relevant_documents(query)
+        retrieved_chunks = semantic_cache_retriever.get_relevant_documents(query)
+
+        t = time.time() - t0
+
+        print(list(store.yield_keys()))
+        print('time:', t)
+
+        print("Result:")
+        for r in retrieved_chunks:
+            print(r.page_content[:50])
+        print()
requirements old.txt
ADDED
@@ -0,0 +1,91 @@
+aiohttp==3.9.5
+aiolimiter==1.1.0
+aiosignal==1.3.1
+altair==5.3.0
+annotated-types==0.6.0
+async-timeout==4.0.3
+attrs==23.2.0
+blinker==1.8.2
+cachetools==5.3.3
+certifi==2024.2.2
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+Cython==3.0.10
+dataclasses-json==0.6.6
+elastic-transport==8.13.0
+elasticsearch==8.13.1
+filelock==3.14.0
+frozenlist==1.4.1
+fsspec==2024.3.1
+gitdb==4.0.11
+GitPython==3.1.43
+greenlet==3.0.3
+huggingface-hub==0.23.0
+idna==3.7
+intel-openmp==2021.4.0
+Jinja2==3.1.4
+joblib==1.4.2
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.22.0
+jsonschema-specifications==2023.12.1
+langchain==0.1.20
+langchain-community==0.0.38
+langchain-core==0.1.52
+langchain-text-splitters==0.0.1
+langsmith==0.1.57
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.21.2
+mdurl==0.1.2
+mkl==2021.4.0
+mpmath==1.3.0
+multidict==6.0.5
+mypy-extensions==1.0.0
+networkx==3.2.1
+numpy==1.26.4
+orjson==3.10.3
+packaging==23.2
+pandas==2.2.2
+pillow==10.3.0
+protobuf==4.25.3
+pyarrow==16.1.0
+pydantic==2.7.1
+pydantic_core==2.18.2
+pydeck==0.9.1
+Pygments==2.18.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+pytz==2024.1
+PyYAML==6.0.1
+referencing==0.35.1
+regex==2024.5.10
+requests==2.31.0
+rich==13.7.1
+rpds-py==0.18.1
+safetensors==0.4.3
+scikit-learn==1.4.2
+scipy==1.13.0
+sentence-transformers==2.7.0
+six==1.16.0
+smmap==5.0.1
+SQLAlchemy==2.0.30
+streamlit==1.34.0
+sympy==1.12
+tbb==2021.12.0
+tenacity==8.3.0
+threadpoolctl==3.5.0
+tokenizers==0.19.1
+toml==0.10.2
+toolz==0.12.1
+torch==2.3.0
+tornado==6.4
+tqdm==4.66.4
+transformers==4.40.2
+typing-inspect==0.9.0
+typing_extensions==4.11.0
+tzdata==2024.1
+urllib3==2.2.1
+watchdog==4.0.0
+yarl==1.9.4
requirements.txt
CHANGED
@@ -3,6 +3,7 @@ aiolimiter==1.1.0
 aiosignal==1.3.1
 altair==5.3.0
 annotated-types==0.6.0
+anyio==4.3.0
 async-timeout==4.0.3
 attrs==23.2.0
 blinker==1.8.2
@@ -11,17 +12,31 @@ certifi==2024.2.2
 charset-normalizer==3.3.2
 click==8.1.7
 colorama==0.4.6
+coloredlogs==15.0.1
 Cython==3.0.10
 dataclasses-json==0.6.6
 elastic-transport==8.13.0
 elasticsearch==8.13.1
+exceptiongroup==1.2.1
+faiss-cpu==1.8.0
+fastembed==0.2.6
 filelock==3.14.0
+flatbuffers==24.3.25
 frozenlist==1.4.1
 fsspec==2024.3.1
 gitdb==4.0.11
 GitPython==3.1.43
 greenlet==3.0.3
-huggingface-hub==0.23.0
+grpcio==1.63.0
+grpcio-tools==1.63.0
+h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
+httpcore==1.0.5
+httpx==0.27.0
+huggingface-hub==0.20.3
+humanfriendly==10.0
+hyperframe==6.0.1
 idna==3.7
 intel-openmp==2021.4.0
 Jinja2==3.1.4
@@ -35,6 +50,8 @@ langchain-community==0.0.38
 langchain-core==0.1.52
 langchain-text-splitters==0.0.1
 langsmith==0.1.57
+llvmlite==0.42.0
+loguru==0.7.2
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 marshmallow==3.21.2
@@ -44,21 +61,29 @@ mpmath==1.3.0
 multidict==6.0.5
 mypy-extensions==1.0.0
 networkx==3.2.1
+numba==0.59.1
 numpy==1.26.4
+onnx==1.16.0
+onnxruntime==1.17.3
 orjson==3.10.3
 packaging==23.2
 pandas==2.2.2
 pillow==10.3.0
-protobuf==4.25.3
+portalocker==2.8.2
+protobuf==5.26.1
 pyarrow==16.1.0
 pydantic==2.7.1
 pydantic_core==2.18.2
 pydeck==0.9.1
 Pygments==2.18.0
+pyreadline3==3.4.1
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 pytz==2024.1
+pywin32==306
 PyYAML==6.0.1
+qdrant-client==1.9.1
+rankerEval==0.2.0
 referencing==0.35.1
 regex==2024.5.10
 requests==2.31.0
@@ -67,9 +92,11 @@ rpds-py==0.18.1
 safetensors==0.4.3
 scikit-learn==1.4.2
 scipy==1.13.0
+semantic-cache==0.1.1
 sentence-transformers==2.7.0
 six==1.16.0
 smmap==5.0.1
+sniffio==1.3.1
 SQLAlchemy==2.0.30
 streamlit==1.34.0
 sympy==1.12
@@ -88,4 +115,5 @@ typing_extensions==4.11.0
 tzdata==2024.1
 urllib3==2.2.1
 watchdog==4.0.0
+win32-setctime==1.1.0
 yarl==1.9.4
retrieval_pipeline/cache.py
ADDED
@@ -0,0 +1,94 @@
+import faiss
+from sentence_transformers import SentenceTransformer
+import time
+import json
+
+from langchain_core.documents import Document
+
+def init_cache():
+    index = faiss.IndexFlatL2(1024)
+    if index.is_trained:
+        print("Index trained")
+
+    # Initialize Sentence Transformer model
+    encoder = SentenceTransformer("multilingual-e5-large")
+
+    return index, encoder
+
+def retrieve_cache(json_file):
+    try:
+        with open(json_file, "r") as file:
+            cache = json.load(file)
+    except FileNotFoundError:
+        cache = {"query": [], "embeddings": [], "answers": [], "response_text": []}
+
+    return cache
+
+def store_cache(json_file, cache):
+    with open(json_file, "w") as file:
+        json.dump(cache, file)
+
+class SemanticCache:
+    def __init__(self, retriever, json_file="cache_file.json", thresold=0.35):
+        # Initialize Faiss index with Euclidean distance
+        self.retriever = retriever
+        self.index, self.encoder = init_cache()
+
+        # Set Euclidean distance threshold
+        # a distance of 0 means identicals sentences
+        # We only return from cache sentences under this thresold
+        self.euclidean_threshold = thresold
+
+        self.json_file = json_file
+        self.cache = retrieve_cache(self.json_file)
+
+    def query_database(self, query_text):
+        results = self.retriever.get_relevant_documents(query_text)
+        return results
+
+    def get_relevant_documents(self, query: str) -> str:
+        # Method to retrieve an answer from the cache or generate a new one
+        start_time = time.time()
+        # try:
+        # First we obtain the embeddings corresponding to the user query
+        embedding = self.encoder.encode([query])
+
+        # Search for the nearest neighbor in the index
+        self.index.nprobe = 8
+        D, I = self.index.search(embedding, 1)
+
+        if D[0] >= 0:
+            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
+                row_id = int(I[0][0])
+
+                print("Answer recovered from Cache. ")
+                print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
+                print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
+
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                print(f"Time taken: {elapsed_time:.3f} seconds")
+                return [Document(**doc[k]) for doc in self.cache["answers"][row_id]]
+
+        # Handle the case when there are not enough results
+        # or Euclidean distance is not met, asking to chromaDB.
+        answer = self.query_database(query)
+        # response_text = answer["documents"][0][0]
+
+        self.cache["query"].append(query)
+        self.cache["embeddings"].append(embedding[0].tolist())
+        self.cache["answers"].append([doc.__dict__ for doc in answer])
+        # self.cache["response_text"].append(response_text)
+
+        print("Answer recovered from ChromaDB. ")
+        # print(f"response_text: {response_text}")
+
+        self.index.add(embedding)
+        store_cache(self.json_file, self.cache)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        print(f"Time taken: {elapsed_time:.3f} seconds")
+
+        return answer
+        # except Exception as e:
+        #     raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}")
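Note on the cache-hit branch above: the committed line `return [Document(**doc[k]) for doc in self.cache["answers"][row_id]]` references a name `k` that is never defined, so a cache hit would raise a NameError. Since entries are stored as `doc.__dict__` dictionaries, a minimal corrected sketch (an assumption about the intended behaviour, not part of the commit) could rebuild the documents like this:

from langchain_core.documents import Document

def rebuild_cached_answer(cached_answer):
    # cached_answer is one entry of cache["answers"]: a list of dicts
    # captured via doc.__dict__ when the answer was first stored.
    return [
        Document(page_content=d.get("page_content", ""), metadata=d.get("metadata", {}))
        for d in cached_answer
    ]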
retrieval_pipeline/hybrid_search.py
CHANGED
@@ -9,6 +9,10 @@ import elasticsearch
 
 from typing import Optional, List
 
+from langchain.storage import LocalFileStore
+from langchain.embeddings import CacheBackedEmbeddings
+
+store = LocalFileStore("cache")
 
 class HybridRetriever(BaseRetriever):
     dense_db: ElasticVectorSearch
@@ -68,10 +72,6 @@ class HybridRetriever(BaseRetriever):
 
         # Combine results (you'll need a strategy here)
         combined_results = dense_results + sparse_results
-        # result_text = [doc.page_content for doc in combined_results]
-
-        # reranked_result = rerank.rerank(query, documents=result_text, model="rerank-lite-1", top_k=self.top_k_dense+self.top_k_sparse)
-        # reranked_result = sorted(reranked_result.results, key=lambda result: result.index)
 
         # Create LangChain Documents
         documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
@@ -82,10 +82,21 @@ class HybridRetriever(BaseRetriever):
         raise NotImplementedError
 
 def get_dense_db(elasticsearch_url, index_dense, embeddings):
+    # retriever cache
+    cached_embedder = CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store,
+        namespace='sentence-transformer',
+        # query_embedding_store=store,
+        # query_embedding_cache=True
+    )
+
+    cached_embedder.query_embedding_store = store
+
    dense_db = ElasticVectorSearch(
        elasticsearch_url=elasticsearch_url,
        index_name=index_dense,
        embedding=embeddings,
+        # embedding=cached_embedder,
    )
    return dense_db
 
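In the new get_dense_db, a CacheBackedEmbeddings wrapper is built over the shared LocalFileStore("cache"), but ElasticVectorSearch is still constructed with the raw embeddings (the embedding=cached_embedder argument stays commented out), and query_embedding_store is not a documented attribute of the wrapper. A minimal sketch of the usual wiring, assuming a LangChain embeddings object and treating the namespace as illustrative:

from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings

store = LocalFileStore("cache")

def get_cached_embedder(embeddings):
    # Wrap the embedder so repeated texts are served from the on-disk
    # byte store instead of being re-encoded.
    return CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace="sentence-transformer",
    )

# Passing the wrapper instead of the raw embeddings is what routes
# indexing through the cache, e.g.:
# dense_db = ElasticVectorSearch(
#     elasticsearch_url=elasticsearch_url,
#     index_name=index_dense,
#     embedding=get_cached_embedder(embeddings),
# )

By default the wrapper caches document embeddings only; depending on the installed LangChain version, from_bytes_store may also accept a query_embedding_cache argument for caching query embeddings.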