manfredmichael committed
Commit 36623c8 · 1 Parent(s): 44fcc42
benchmark.py CHANGED
@@ -7,6 +7,7 @@ TOP_N = 3
 
 def get_benchmark_result(path, retriever):
     df = pd.read_csv(path)
+
     retrieval_result = []
     query_result = [[] for i in range(TOP_N)]
     retrieval_latency = []
@@ -21,13 +22,13 @@ def get_benchmark_result(path, retriever):
         t0 = time.time()
         results = retriever.get_relevant_documents(query)
         t = time.time() - t0
-        retrieval_latency.append(t)
+        retrieval_latency.append(str(t))
 
         result_content = [result.page_content for result in results]
         # results_content = get_relevant_documents(query, retriever, top_k=5)
 
         for i, text in enumerate(result_content):
-            query_result[i] = text
+            query_result[i].append(text)
 
         if target in result_content:
             retrieval_result.append("Success")
@@ -37,10 +38,10 @@ def get_benchmark_result(path, retriever):
         # break
 
     df["retrieval_result"] = retrieval_result
-    df["retrieval_latency"] = retrieval_latency
+    df["retrieval_latency"] = retrieval_latency
     for i in range(TOP_N):
         df[f'q{i+1}'] = query_result[i]
-    df.to_csv('benchmark_result q3 topk 5.csv')
+    df.to_csv('benchmark_result.csv')
     print(df['retrieval_result'].value_counts())
     print(df['retrieval_result'].value_counts()/ len(df))
 
 
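The two behavioural fixes in get_benchmark_result are what make the exported columns line up: each of the TOP_N lists in query_result now accumulates one entry per query (so the q1..q3 columns have the same length as the DataFrame), and a latency value is appended on every iteration before being assigned as a column. A minimal sketch of the intended layout, with made-up queries and results (nothing below is from the repo):

# Sketch only: how the per-rank lists line up with the DataFrame columns.
TOP_N = 3
queries = ["gunung semeru", "gunung bromo"]           # hypothetical benchmark queries
fake_results = {                                      # hypothetical top-3 results per query
    "gunung semeru": ["doc a", "doc b", "doc c"],
    "gunung bromo":  ["doc x", "doc y", "doc z"],
}

query_result = [[] for _ in range(TOP_N)]
for q in queries:
    for i, text in enumerate(fake_results[q]):
        query_result[i].append(text)                  # one entry per query, per rank

# query_result[0] == ["doc a", "doc x"]  -> column q1
# query_result[1] == ["doc b", "doc y"]  -> column q2
# query_result[2] == ["doc c", "doc z"]  -> column q3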
main.py CHANGED
@@ -1,10 +1,13 @@
 from dotenv import load_dotenv
 import json
-import os
+import os, time
 import uuid
 
 from retrieval_pipeline import get_retriever, get_compression_retriever
 import benchmark
+from retrieval_pipeline.hybrid_search import store
+
+from retrieval_pipeline.cache import SemanticCache
 
 load_dotenv()
 ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
@@ -16,18 +19,26 @@ print(ELASTICSEARCH_URL)
 if __name__ == "__main__":
     retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
     compression_retriever = get_compression_retriever(retriever)
+
+    semantic_cache_retriever = SemanticCache(compression_retriever)
+
     retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
     print(retrieved_chunks)
-
-    # retrieved_chunks = retriever.get_relevant_documents('Gunung Semeru')
-    # print(retrieved_chunks)
-
-    benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
-
-    # for i in range(100):
-    #     query = input("query: ")
-    #     retrieved_chunks = retriever.get_relevant_documents(query)
-    #     print("Result:")
-    #     for r in retrieved_chunks:
-    #         print(r.page_content[:50])
-    #         print()
+
+    # benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
+
+    for i in range(100):
+        query = input("query: ")
+        t0 = time.time()
+        # retrieved_chunks = compression_retriever.get_relevant_documents(query)
+        retrieved_chunks = semantic_cache_retriever.get_relevant_documents(query)
+
+        t = time.time() - t0
+
+        print(list(store.yield_keys()))
+        print('time:', t)
+
+        print("Result:")
+        for r in retrieved_chunks:
+            print(r.page_content[:50])
+            print()
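The rewritten __main__ block trades the one-off benchmark call for an interactive loop: every query is answered through SemanticCache wrapping the compression retriever, the elapsed time is printed, and store.yield_keys() lists whatever the LocalFileStore embedding cache from hybrid_search currently holds. A self-contained sketch of the same timing pattern, using a hypothetical DummyRetriever instead of the real pipeline:

import time

class DummyRetriever:
    """Stand-in for compression_retriever / SemanticCache in this sketch."""
    def get_relevant_documents(self, query):
        time.sleep(0.05)                      # pretend retrieval takes 50 ms
        return [f"chunk for {query!r}"]

retriever = DummyRetriever()
for query in ["Gunung Semeru", "Gunung Semeru"]:  # repeated query mimics a cache-friendly workload
    t0 = time.time()
    chunks = retriever.get_relevant_documents(query)
    print(f"{query}: {time.time() - t0:.3f}s, {len(chunks)} chunks")

With a semantic cache in front of the retriever, the repeated query would be expected to come back noticeably faster than the first one.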
requirements old.txt ADDED
@@ -0,0 +1,91 @@
+aiohttp==3.9.5
+aiolimiter==1.1.0
+aiosignal==1.3.1
+altair==5.3.0
+annotated-types==0.6.0
+async-timeout==4.0.3
+attrs==23.2.0
+blinker==1.8.2
+cachetools==5.3.3
+certifi==2024.2.2
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+Cython==3.0.10
+dataclasses-json==0.6.6
+elastic-transport==8.13.0
+elasticsearch==8.13.1
+filelock==3.14.0
+frozenlist==1.4.1
+fsspec==2024.3.1
+gitdb==4.0.11
+GitPython==3.1.43
+greenlet==3.0.3
+huggingface-hub==0.23.0
+idna==3.7
+intel-openmp==2021.4.0
+Jinja2==3.1.4
+joblib==1.4.2
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.22.0
+jsonschema-specifications==2023.12.1
+langchain==0.1.20
+langchain-community==0.0.38
+langchain-core==0.1.52
+langchain-text-splitters==0.0.1
+langsmith==0.1.57
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.21.2
+mdurl==0.1.2
+mkl==2021.4.0
+mpmath==1.3.0
+multidict==6.0.5
+mypy-extensions==1.0.0
+networkx==3.2.1
+numpy==1.26.4
+orjson==3.10.3
+packaging==23.2
+pandas==2.2.2
+pillow==10.3.0
+protobuf==4.25.3
+pyarrow==16.1.0
+pydantic==2.7.1
+pydantic_core==2.18.2
+pydeck==0.9.1
+Pygments==2.18.0
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+pytz==2024.1
+PyYAML==6.0.1
+referencing==0.35.1
+regex==2024.5.10
+requests==2.31.0
+rich==13.7.1
+rpds-py==0.18.1
+safetensors==0.4.3
+scikit-learn==1.4.2
+scipy==1.13.0
+sentence-transformers==2.7.0
+six==1.16.0
+smmap==5.0.1
+SQLAlchemy==2.0.30
+streamlit==1.34.0
+sympy==1.12
+tbb==2021.12.0
+tenacity==8.3.0
+threadpoolctl==3.5.0
+tokenizers==0.19.1
+toml==0.10.2
+toolz==0.12.1
+torch==2.3.0
+tornado==6.4
+tqdm==4.66.4
+transformers==4.40.2
+typing-inspect==0.9.0
+typing_extensions==4.11.0
+tzdata==2024.1
+urllib3==2.2.1
+watchdog==4.0.0
+yarl==1.9.4
requirements.txt CHANGED
@@ -3,6 +3,7 @@ aiolimiter==1.1.0
 aiosignal==1.3.1
 altair==5.3.0
 annotated-types==0.6.0
+anyio==4.3.0
 async-timeout==4.0.3
 attrs==23.2.0
 blinker==1.8.2
@@ -11,17 +12,31 @@ certifi==2024.2.2
 charset-normalizer==3.3.2
 click==8.1.7
 colorama==0.4.6
+coloredlogs==15.0.1
 Cython==3.0.10
 dataclasses-json==0.6.6
 elastic-transport==8.13.0
 elasticsearch==8.13.1
+exceptiongroup==1.2.1
+faiss-cpu==1.8.0
+fastembed==0.2.6
 filelock==3.14.0
+flatbuffers==24.3.25
 frozenlist==1.4.1
 fsspec==2024.3.1
 gitdb==4.0.11
 GitPython==3.1.43
 greenlet==3.0.3
-huggingface-hub==0.23.0
+grpcio==1.63.0
+grpcio-tools==1.63.0
+h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
+httpcore==1.0.5
+httpx==0.27.0
+huggingface-hub==0.20.3
+humanfriendly==10.0
+hyperframe==6.0.1
 idna==3.7
 intel-openmp==2021.4.0
 Jinja2==3.1.4
@@ -35,6 +50,8 @@ langchain-community==0.0.38
 langchain-core==0.1.52
 langchain-text-splitters==0.0.1
 langsmith==0.1.57
+llvmlite==0.42.0
+loguru==0.7.2
 markdown-it-py==3.0.0
 MarkupSafe==2.1.5
 marshmallow==3.21.2
@@ -44,21 +61,29 @@ mpmath==1.3.0
 multidict==6.0.5
 mypy-extensions==1.0.0
 networkx==3.2.1
+numba==0.59.1
 numpy==1.26.4
+onnx==1.16.0
+onnxruntime==1.17.3
 orjson==3.10.3
 packaging==23.2
 pandas==2.2.2
 pillow==10.3.0
-protobuf==4.25.3
+portalocker==2.8.2
+protobuf==5.26.1
 pyarrow==16.1.0
 pydantic==2.7.1
 pydantic_core==2.18.2
 pydeck==0.9.1
 Pygments==2.18.0
+pyreadline3==3.4.1
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 pytz==2024.1
+pywin32==306
 PyYAML==6.0.1
+qdrant-client==1.9.1
+rankerEval==0.2.0
 referencing==0.35.1
 regex==2024.5.10
 requests==2.31.0
@@ -67,9 +92,11 @@ rpds-py==0.18.1
 safetensors==0.4.3
 scikit-learn==1.4.2
 scipy==1.13.0
+semantic-cache==0.1.1
 sentence-transformers==2.7.0
 six==1.16.0
 smmap==5.0.1
+sniffio==1.3.1
 SQLAlchemy==2.0.30
 streamlit==1.34.0
 sympy==1.12
@@ -88,4 +115,5 @@ typing_extensions==4.11.0
 tzdata==2024.1
 urllib3==2.2.1
 watchdog==4.0.0
+win32-setctime==1.1.0
 yarl==1.9.4
retrieval_pipeline/cache.py ADDED
@@ -0,0 +1,94 @@
+import faiss
+from sentence_transformers import SentenceTransformer
+import time
+import json
+
+from langchain_core.documents import Document
+
+def init_cache():
+    index = faiss.IndexFlatL2(1024)
+    if index.is_trained:
+        print("Index trained")
+
+    # Initialize Sentence Transformer model
+    encoder = SentenceTransformer("multilingual-e5-large")
+
+    return index, encoder
+
+def retrieve_cache(json_file):
+    try:
+        with open(json_file, "r") as file:
+            cache = json.load(file)
+    except FileNotFoundError:
+        cache = {"query": [], "embeddings": [], "answers": [], "response_text": []}
+
+    return cache
+
+def store_cache(json_file, cache):
+    with open(json_file, "w") as file:
+        json.dump(cache, file)
+
+class SemanticCache:
+    def __init__(self, retriever, json_file="cache_file.json", thresold=0.35):
+        # Initialize Faiss index with Euclidean distance
+        self.retriever = retriever
+        self.index, self.encoder = init_cache()
+
+        # Set Euclidean distance threshold
+        # a distance of 0 means identical sentences
+        # We only return from cache sentences under this threshold
+        self.euclidean_threshold = thresold
+
+        self.json_file = json_file
+        self.cache = retrieve_cache(self.json_file)
+
+    def query_database(self, query_text):
+        results = self.retriever.get_relevant_documents(query_text)
+        return results
+
+    def get_relevant_documents(self, query: str) -> str:
+        # Method to retrieve an answer from the cache or generate a new one
+        start_time = time.time()
+        # try:
+        # First we obtain the embeddings corresponding to the user query
+        embedding = self.encoder.encode([query])
+
+        # Search for the nearest neighbor in the index
+        self.index.nprobe = 8
+        D, I = self.index.search(embedding, 1)
+
+        if D[0] >= 0:
+            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
+                row_id = int(I[0][0])
+
+                print("Answer recovered from Cache. ")
+                print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
+                print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
+
+                end_time = time.time()
+                elapsed_time = end_time - start_time
+                print(f"Time taken: {elapsed_time:.3f} seconds")
+                return [Document(**doc) for doc in self.cache["answers"][row_id]]
+
+        # Handle the case when there is no cached match,
+        # or the distance threshold is not met: fall back to the wrapped retriever.
+        answer = self.query_database(query)
+        # response_text = answer["documents"][0][0]
+
+        self.cache["query"].append(query)
+        self.cache["embeddings"].append(embedding[0].tolist())
+        self.cache["answers"].append([doc.__dict__ for doc in answer])
+        # self.cache["response_text"].append(response_text)
+
+        print("Answer recovered from ChromaDB. ")
+        # print(f"response_text: {response_text}")
+
+        self.index.add(embedding)
+        store_cache(self.json_file, self.cache)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        print(f"Time taken: {elapsed_time:.3f} seconds")
+
+        return answer
+        # except Exception as e:
+        #     raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}")
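SemanticCache embeds each incoming query with a SentenceTransformer, looks up its nearest neighbour in a FAISS IndexFlatL2 of previously seen query embeddings, and returns the stored documents when the L2 distance is within euclidean_threshold; otherwise it falls through to the wrapped retriever and records the query, its embedding, and the answer in cache_file.json. A usage sketch under the assumption that the "multilingual-e5-large" encoder can be loaded locally; ListRetriever and the queries below are illustrative, not part of the repo:

from langchain_core.documents import Document
from retrieval_pipeline.cache import SemanticCache

class ListRetriever:
    """Hypothetical stand-in for the compression retriever."""
    def __init__(self, docs):
        self.docs = docs
    def get_relevant_documents(self, query):
        return self.docs

docs = [Document(page_content="Semeru adalah gunung tertinggi di Pulau Jawa.")]
cache = SemanticCache(ListRetriever(docs), json_file="demo_cache.json", thresold=0.35)

cache.get_relevant_documents("gunung tertinggi di pulau Jawa")  # miss: goes to the retriever, then cached
cache.get_relevant_documents("gunung paling tinggi di Jawa")    # near-duplicate: may be served from the FAISS index if within the threshold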
retrieval_pipeline/hybrid_search.py CHANGED
@@ -9,6 +9,10 @@ import elasticsearch
 
 from typing import Optional, List
 
+from langchain.storage import LocalFileStore
+from langchain.embeddings import CacheBackedEmbeddings
+
+store = LocalFileStore("cache")
 
 class HybridRetriever(BaseRetriever):
     dense_db: ElasticVectorSearch
@@ -68,10 +72,6 @@ class HybridRetriever(BaseRetriever):
 
         # Combine results (you'll need a strategy here)
         combined_results = dense_results + sparse_results
-        # result_text = [doc.page_content for doc in combined_results]
-
-        # reranked_result = rerank.rerank(query, documents=result_text, model="rerank-lite-1", top_k=self.top_k_dense+self.top_k_sparse)
-        # reranked_result = sorted(reranked_result.results, key=lambda result: result.index)
 
         # Create LangChain Documents
         documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
@@ -82,10 +82,21 @@ class HybridRetriever(BaseRetriever):
        raise NotImplementedError
 
 def get_dense_db(elasticsearch_url, index_dense, embeddings):
+    # retriever cache
+    cached_embedder = CacheBackedEmbeddings.from_bytes_store(
+        embeddings, store,
+        namespace='sentence-transformer',
+        # query_embedding_store=store,
+        # query_embedding_cache=True
+    )
+
+    cached_embedder.query_embedding_store = store
+
     dense_db = ElasticVectorSearch(
        elasticsearch_url=elasticsearch_url,
        index_name=index_dense,
        embedding=embeddings,
+       # embedding=cached_embedder,
     )
     return dense_db
 
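The hybrid_search change wires a byte-store cache around the dense embeddings: CacheBackedEmbeddings.from_bytes_store wraps the embeddings object so repeated embed_documents calls for the same text are read back from the LocalFileStore("cache") instead of being recomputed. Note that ElasticVectorSearch is still constructed with the plain embeddings object (the cached embedder only appears in a commented-out argument), and that this wrapper does not cache embed_query by default, which is presumably what the commented-out query_embedding_store / query_embedding_cache arguments and the manual attribute assignment were probing at. A minimal sketch of the pattern with a stand-in embeddings class so it runs without a model download:

# Sketch of the CacheBackedEmbeddings pattern used above; FakeEmbeddings is a
# stand-in for the real sentence-transformer embeddings object.
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import Embeddings

class FakeEmbeddings(Embeddings):
    def embed_documents(self, texts):
        return [[float(len(t)), 0.0] for t in texts]
    def embed_query(self, text):
        return [float(len(text)), 0.0]

store = LocalFileStore("demo_cache")
cached = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(), store, namespace="demo",
)

cached.embed_documents(["Gunung Semeru"])   # first call computes and writes to the store
cached.embed_documents(["Gunung Semeru"])   # second call is served from the byte store
print(list(store.yield_keys()))             # one cached key, prefixed with the namespace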