Spaces:
Runtime error
Runtime error
Commit
·
8b9c9ff
1
Parent(s):
bf1e0a0
#perf added hybrid search using bm25 + semantic, minor change to text, splitter, and retrieval hyperparameters
Browse files- ai_generate.py +46 -13
ai_generate.py
CHANGED
@@ -1,5 +1,11 @@
|
|
1 |
import gc
|
2 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain_community.document_loaders import PyMuPDFLoader
|
4 |
from langchain_core.documents import Document
|
5 |
from langchain_community.embeddings.sentence_transformer import (
|
@@ -15,14 +21,12 @@ from langchain_anthropic import ChatAnthropic
|
|
15 |
from dotenv import load_dotenv
|
16 |
from langchain_core.output_parsers import XMLOutputParser
|
17 |
from langchain.prompts import ChatPromptTemplate
|
18 |
-
import re
|
19 |
-
import numpy as np
|
20 |
-
import torch
|
21 |
-
import bm25s
|
22 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
23 |
from langchain.retrievers import ContextualCompressionRetriever
|
24 |
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
25 |
from langchain_core.messages import HumanMessage
|
|
|
|
|
26 |
|
27 |
load_dotenv()
|
28 |
|
@@ -33,8 +37,10 @@ os.environ["GLOG_minloglevel"] = "2"
|
|
33 |
# RAG parameters
|
34 |
CHUNK_SIZE = 1024
|
35 |
CHUNK_OVERLAP = CHUNK_SIZE // 8
|
36 |
-
K =
|
37 |
FETCH_K = 50
|
|
|
|
|
38 |
|
39 |
model_kwargs = {"device": "cuda:1"}
|
40 |
print("Loading embedding and reranker models...")
|
@@ -44,7 +50,7 @@ embedding_function = SentenceTransformerEmbeddings(
|
|
44 |
# "sentence-transformers/all-MiniLM-L6-v2"
|
45 |
# "mixedbread-ai/mxbai-embed-large-v1"
|
46 |
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
|
47 |
-
compressor = CrossEncoderReranker(model=reranker, top_n=
|
48 |
|
49 |
llm_model_translation = {
|
50 |
"LLaMA 3": "llama3-70b-8192",
|
@@ -212,7 +218,30 @@ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int
|
|
212 |
|
213 |
def create_db_with_langchain(path: list[str], url_content: dict, query: str):
|
214 |
all_docs = []
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
if path:
|
217 |
for file in path:
|
218 |
loader = PyMuPDFLoader(file)
|
@@ -244,17 +273,19 @@ def create_db_with_langchain(path: list[str], url_content: dict, query: str):
|
|
244 |
for idx, doc in enumerate(all_docs):
|
245 |
print(f"Doc: {idx} | Length = {len(doc.page_content)}")
|
246 |
|
|
|
|
|
|
|
247 |
assert len(all_docs) > 0, "No PDFs or scrapped data provided"
|
248 |
db = Chroma.from_documents(all_docs, embedding_function)
|
249 |
torch.cuda.empty_cache()
|
250 |
gc.collect()
|
251 |
-
return db
|
252 |
|
253 |
|
254 |
def pretty_print_docs(docs):
|
255 |
print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
|
256 |
|
257 |
-
|
258 |
def generate_rag(
|
259 |
prompt: str,
|
260 |
input_role: str,
|
@@ -275,12 +306,14 @@ def generate_rag(
|
|
275 |
|
276 |
query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
|
277 |
print("### Query: ", query)
|
278 |
-
db = create_db_with_langchain(path, url_content, query)
|
279 |
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
|
280 |
-
|
281 |
-
|
282 |
-
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=
|
283 |
docs = compression_retriever.invoke(query)
|
|
|
|
|
284 |
print(pretty_print_docs(docs))
|
285 |
|
286 |
formatted_docs = format_docs_xml(docs)
|
|
|
1 |
import gc
|
2 |
import os
|
3 |
+
import time
|
4 |
+
import re
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import bm25s
|
8 |
+
|
9 |
from langchain_community.document_loaders import PyMuPDFLoader
|
10 |
from langchain_core.documents import Document
|
11 |
from langchain_community.embeddings.sentence_transformer import (
|
|
|
21 |
from dotenv import load_dotenv
|
22 |
from langchain_core.output_parsers import XMLOutputParser
|
23 |
from langchain.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
|
|
24 |
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
25 |
from langchain.retrievers import ContextualCompressionRetriever
|
26 |
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
27 |
from langchain_core.messages import HumanMessage
|
28 |
+
from langchain.retrievers import EnsembleRetriever
|
29 |
+
from langchain_community.retrievers import BM25Retriever
|
30 |
|
31 |
load_dotenv()
|
32 |
|
|
|
37 |
# RAG parameters
|
38 |
CHUNK_SIZE = 1024
|
39 |
CHUNK_OVERLAP = CHUNK_SIZE // 8
|
40 |
+
K = 20 # number of chunks to retrieve from semantic search
|
41 |
FETCH_K = 50
|
42 |
+
N_BM25 = 20 # number of chunks to retrieve from keyword search
|
43 |
+
TOP_N = 10 # final number of chunks to keep
|
44 |
|
45 |
model_kwargs = {"device": "cuda:1"}
|
46 |
print("Loading embedding and reranker models...")
|
|
|
50 |
# "sentence-transformers/all-MiniLM-L6-v2"
|
51 |
# "mixedbread-ai/mxbai-embed-large-v1"
|
52 |
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
|
53 |
+
compressor = CrossEncoderReranker(model=reranker, top_n=TOP_N)
|
54 |
|
55 |
llm_model_translation = {
|
56 |
"LLaMA 3": "llama3-70b-8192",
|
|
|
218 |
|
219 |
def create_db_with_langchain(path: list[str], url_content: dict, query: str):
|
220 |
all_docs = []
|
221 |
+
|
222 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
223 |
+
chunk_size=CHUNK_SIZE,
|
224 |
+
chunk_overlap=CHUNK_OVERLAP,
|
225 |
+
separators=[
|
226 |
+
"\n\n",
|
227 |
+
"\n",
|
228 |
+
".",
|
229 |
+
"\uff0e", # Fullwidth full stop
|
230 |
+
"\u3002", # Ideographic full stop
|
231 |
+
"?",
|
232 |
+
"!",
|
233 |
+
",",
|
234 |
+
"\uff0c", # Fullwidth comma
|
235 |
+
"\u3001", # Ideographic comma
|
236 |
+
" ",
|
237 |
+
"\u200B", # Zero-width space
|
238 |
+
"",
|
239 |
+
],
|
240 |
+
keep_separator=True,
|
241 |
+
is_separator_regex=False,
|
242 |
+
length_function=len,
|
243 |
+
add_start_index=False,
|
244 |
+
)
|
245 |
if path:
|
246 |
for file in path:
|
247 |
loader = PyMuPDFLoader(file)
|
|
|
273 |
for idx, doc in enumerate(all_docs):
|
274 |
print(f"Doc: {idx} | Length = {len(doc.page_content)}")
|
275 |
|
276 |
+
bm25_retriever = BM25Retriever.from_documents(all_docs)
|
277 |
+
bm25_retriever.k = N_BM25
|
278 |
+
|
279 |
assert len(all_docs) > 0, "No PDFs or scrapped data provided"
|
280 |
db = Chroma.from_documents(all_docs, embedding_function)
|
281 |
torch.cuda.empty_cache()
|
282 |
gc.collect()
|
283 |
+
return db, bm25_retriever
|
284 |
|
285 |
|
286 |
def pretty_print_docs(docs):
|
287 |
print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
|
288 |
|
|
|
289 |
def generate_rag(
|
290 |
prompt: str,
|
291 |
input_role: str,
|
|
|
306 |
|
307 |
query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
|
308 |
print("### Query: ", query)
|
309 |
+
db, bm25_retriever = create_db_with_langchain(path, url_content, query)
|
310 |
retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
|
311 |
+
t0 = time.time()
|
312 |
+
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, retriever], weights=[0.4, 0.6])
|
313 |
+
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=ensemble_retriever)
|
314 |
docs = compression_retriever.invoke(query)
|
315 |
+
t1 = time.time()
|
316 |
+
print(f"Time for retrieval : {t1 - t0:.2f}s")
|
317 |
print(pretty_print_docs(docs))
|
318 |
|
319 |
formatted_docs = format_docs_xml(docs)
|