eljanmahammadli commited on
Commit
8b9c9ff
·
1 Parent(s): bf1e0a0

#perf added hybrid search using bm25 + semantic; minor changes to the text splitter and retrieval hyperparameters

Browse files
Files changed (1) hide show
  1. ai_generate.py +46 -13
ai_generate.py CHANGED
@@ -1,5 +1,11 @@
1
  import gc
2
  import os
 
 
 
 
 
 
3
  from langchain_community.document_loaders import PyMuPDFLoader
4
  from langchain_core.documents import Document
5
  from langchain_community.embeddings.sentence_transformer import (
@@ -15,14 +21,12 @@ from langchain_anthropic import ChatAnthropic
15
  from dotenv import load_dotenv
16
  from langchain_core.output_parsers import XMLOutputParser
17
  from langchain.prompts import ChatPromptTemplate
18
- import re
19
- import numpy as np
20
- import torch
21
- import bm25s
22
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
23
  from langchain.retrievers import ContextualCompressionRetriever
24
  from langchain.retrievers.document_compressors import CrossEncoderReranker
25
  from langchain_core.messages import HumanMessage
 
 
26
 
27
  load_dotenv()
28
 
@@ -33,8 +37,10 @@ os.environ["GLOG_minloglevel"] = "2"
33
  # RAG parameters
34
  CHUNK_SIZE = 1024
35
  CHUNK_OVERLAP = CHUNK_SIZE // 8
36
- K = 10
37
  FETCH_K = 50
 
 
38
 
39
  model_kwargs = {"device": "cuda:1"}
40
  print("Loading embedding and reranker models...")
@@ -44,7 +50,7 @@ embedding_function = SentenceTransformerEmbeddings(
44
  # "sentence-transformers/all-MiniLM-L6-v2"
45
  # "mixedbread-ai/mxbai-embed-large-v1"
46
  reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
47
- compressor = CrossEncoderReranker(model=reranker, top_n=K)
48
 
49
  llm_model_translation = {
50
  "LLaMA 3": "llama3-70b-8192",
@@ -212,7 +218,30 @@ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int
212
 
213
  def create_db_with_langchain(path: list[str], url_content: dict, query: str):
214
  all_docs = []
215
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  if path:
217
  for file in path:
218
  loader = PyMuPDFLoader(file)
@@ -244,17 +273,19 @@ def create_db_with_langchain(path: list[str], url_content: dict, query: str):
244
  for idx, doc in enumerate(all_docs):
245
  print(f"Doc: {idx} | Length = {len(doc.page_content)}")
246
 
 
 
 
247
  assert len(all_docs) > 0, "No PDFs or scrapped data provided"
248
  db = Chroma.from_documents(all_docs, embedding_function)
249
  torch.cuda.empty_cache()
250
  gc.collect()
251
- return db
252
 
253
 
254
  def pretty_print_docs(docs):
255
  print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
256
 
257
-
258
  def generate_rag(
259
  prompt: str,
260
  input_role: str,
@@ -275,12 +306,14 @@ def generate_rag(
275
 
276
  query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
277
  print("### Query: ", query)
278
- db = create_db_with_langchain(path, url_content, query)
279
  retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
280
-
281
- # docs = retriever.get_relevant_documents(query)
282
- compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
283
  docs = compression_retriever.invoke(query)
 
 
284
  print(pretty_print_docs(docs))
285
 
286
  formatted_docs = format_docs_xml(docs)
 
1
  import gc
2
  import os
3
+ import time
4
+ import re
5
+ import numpy as np
6
+ import torch
7
+ import bm25s
8
+
9
  from langchain_community.document_loaders import PyMuPDFLoader
10
  from langchain_core.documents import Document
11
  from langchain_community.embeddings.sentence_transformer import (
 
21
  from dotenv import load_dotenv
22
  from langchain_core.output_parsers import XMLOutputParser
23
  from langchain.prompts import ChatPromptTemplate
 
 
 
 
24
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
25
  from langchain.retrievers import ContextualCompressionRetriever
26
  from langchain.retrievers.document_compressors import CrossEncoderReranker
27
  from langchain_core.messages import HumanMessage
28
+ from langchain.retrievers import EnsembleRetriever
29
+ from langchain_community.retrievers import BM25Retriever
30
 
31
  load_dotenv()
32
 
 
37
  # RAG parameters
38
  CHUNK_SIZE = 1024
39
  CHUNK_OVERLAP = CHUNK_SIZE // 8
40
+ K = 20 # number of chunks to retrieve from semantic search
41
  FETCH_K = 50
42
+ N_BM25 = 20 # number of chunks to retrieve from keyword search
43
+ TOP_N = 10 # final number of chunks to keep
44
 
45
  model_kwargs = {"device": "cuda:1"}
46
  print("Loading embedding and reranker models...")
 
50
  # "sentence-transformers/all-MiniLM-L6-v2"
51
  # "mixedbread-ai/mxbai-embed-large-v1"
52
  reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
53
+ compressor = CrossEncoderReranker(model=reranker, top_n=TOP_N)
54
 
55
  llm_model_translation = {
56
  "LLaMA 3": "llama3-70b-8192",
 
218
 
219
  def create_db_with_langchain(path: list[str], url_content: dict, query: str):
220
  all_docs = []
221
+
222
+ text_splitter = RecursiveCharacterTextSplitter(
223
+ chunk_size=CHUNK_SIZE,
224
+ chunk_overlap=CHUNK_OVERLAP,
225
+ separators=[
226
+ "\n\n",
227
+ "\n",
228
+ ".",
229
+ "\uff0e", # Fullwidth full stop
230
+ "\u3002", # Ideographic full stop
231
+ "?",
232
+ "!",
233
+ ",",
234
+ "\uff0c", # Fullwidth comma
235
+ "\u3001", # Ideographic comma
236
+ " ",
237
+ "\u200B", # Zero-width space
238
+ "",
239
+ ],
240
+ keep_separator=True,
241
+ is_separator_regex=False,
242
+ length_function=len,
243
+ add_start_index=False,
244
+ )
245
  if path:
246
  for file in path:
247
  loader = PyMuPDFLoader(file)
 
273
  for idx, doc in enumerate(all_docs):
274
  print(f"Doc: {idx} | Length = {len(doc.page_content)}")
275
 
276
+ bm25_retriever = BM25Retriever.from_documents(all_docs)
277
+ bm25_retriever.k = N_BM25
278
+
279
  assert len(all_docs) > 0, "No PDFs or scrapped data provided"
280
  db = Chroma.from_documents(all_docs, embedding_function)
281
  torch.cuda.empty_cache()
282
  gc.collect()
283
+ return db, bm25_retriever
284
 
285
 
286
  def pretty_print_docs(docs):
287
  print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
288
 
 
289
  def generate_rag(
290
  prompt: str,
291
  input_role: str,
 
306
 
307
  query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
308
  print("### Query: ", query)
309
+ db, bm25_retriever = create_db_with_langchain(path, url_content, query)
310
  retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
311
+ t0 = time.time()
312
+ ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, retriever], weights=[0.4, 0.6])
313
+ compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble_retriever)
314
  docs = compression_retriever.invoke(query)
315
+ t1 = time.time()
316
+ print(f"Time for retrieval : {t1 - t0:.2f}s")
317
  print(pretty_print_docs(docs))
318
 
319
  formatted_docs = format_docs_xml(docs)