File size: 4,842 Bytes
df83264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os
os.environ["HF_HOME"] = "weights"
os.environ["TORCH_HOME"] = "weights"
from typing import List, Optional, Union
from langchain.callbacks import FileCallbackHandler
from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.documents import Document
from loguru import logger
from rich import print
from sentence_transformers import CrossEncoder
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
# Route all loguru output to a shared log file; enqueue=True funnels writes
# through a background queue, making logging safe across processes/threads.
logfile = "log/output.log"
logger.add(logfile, colorize=True, enqueue=True)
# LangChain callback handler that mirrors chain/LLM events into the same file.
handler = FileCallbackHandler(logfile)
# None => Chroma keeps its vector index in memory only (no on-disk persistence).
persist_directory = None
class RAGException(Exception):
    """Raised when the retrieval pipeline cannot produce any relevant documents.

    A docstring-only body suffices: the previous explicit ``__init__`` merely
    forwarded ``*args``/``**kwargs`` to ``Exception.__init__`` and added nothing.
    """
def rerank_docs(reranker_model, query, retrieved_docs):
    """Score each retrieved document against *query* and sort best-first.

    Args:
        reranker_model: Object exposing ``predict(pairs)`` -> sequence of
            relevance scores (e.g. a ``sentence_transformers.CrossEncoder``).
        query: User query string.
        retrieved_docs: Documents exposing a ``page_content`` attribute.

    Returns:
        list[tuple[Document, float]]: ``(doc, score)`` pairs ordered by
        descending score. Empty input yields an empty list.
    """
    query_and_docs = [(query, doc.page_content) for doc in retrieved_docs]
    scores = reranker_model.predict(query_and_docs)
    # sorted() accepts any iterable; wrapping zip() in list() was redundant.
    return sorted(zip(retrieved_docs, scores), key=lambda pair: pair[1], reverse=True)
def load_pdf(
    files: Union[str, List[str]] = "example_data/2401.08406.pdf"
) -> List[Document]:
    """Parse one or more PDF files into LangChain documents.

    Every file goes through ``UnstructuredFileLoader`` with extra-whitespace
    cleanup and broken-paragraph regrouping applied as post-processing.

    Args:
        files: A single file path, or a list of paths.

    Returns:
        All documents produced by the loader(s), in input order.
    """
    post_processors = [clean_extra_whitespace, group_broken_paragraphs]
    if isinstance(files, str):
        # Single-path shortcut: one loader, return its documents directly.
        return UnstructuredFileLoader(files, post_processors=post_processors).load()
    documents: List[Document] = []
    for path in files:
        loader = UnstructuredFileLoader(path, post_processors=post_processors)
        documents.extend(loader.load())
    return documents
def create_parent_retriever(
    docs: List[Document], embeddings_model: HuggingFaceBgeEmbeddings
):
    """Build a ``ParentDocumentRetriever`` over *docs*.

    Small child chunks are embedded and indexed in Chroma for similarity
    search; larger parent chunks are kept in an in-memory docstore and are
    what retrieval actually returns.

    Args:
        docs: Documents to split and index.
        embeddings_model: Embedding model used for the child-chunk index.

    Returns:
        A ``ParentDocumentRetriever`` with all *docs* added.
    """
    # Fix: the annotation previously read ``HuggingFaceBgeEmbeddings()`` — a
    # *call*, evaluated at def time, which instantiated (and could download)
    # the embedding model as a module-import side effect.
    # Large chunks handed back to the caller.
    parent_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=2000,
        length_function=len,
        is_separator_regex=False,
    )
    # Smaller, overlapping chunks that are actually embedded and searched.
    child_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=1000,
        chunk_overlap=300,
        length_function=len,
        is_separator_regex=False,
    )
    # Vector index over the child chunks; module-level persist_directory=None
    # keeps it in memory only.
    vectorstore = Chroma(
        collection_name="split_documents",
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    # Storage layer mapping parent ids -> full parent documents.
    store = InMemoryStore()
    # NOTE(review): depending on the installed langchain version, retrieval
    # depth may need to be passed as search_kwargs={"k": 10} instead of k=10 —
    # confirm against the ParentDocumentRetriever API.
    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        k=10,
    )
    retriever.add_documents(docs)
    return retriever
def retrieve_context(query, retriever, reranker_model):
    """Retrieve documents for *query* and rerank them with a cross-encoder.

    Args:
        query: User query string.
        retriever: Object exposing ``get_relevant_documents(query)``.
        reranker_model: Passed through to :func:`rerank_docs`.

    Returns:
        list[tuple[Document, float]]: reranked ``(doc, score)`` pairs,
        best match first.

    Raises:
        RAGException: If the retriever returns no documents at all.
    """
    retrieved_docs = retriever.get_relevant_documents(query)
    # Empty hit list: fail loudly with an actionable message rather than
    # silently returning nothing.
    if not retrieved_docs:
        raise RAGException(
            f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
        )
    return rerank_docs(
        query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
    )
def load_embedding_model(
    model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cpu"
) -> HuggingFaceBgeEmbeddings:
    """Instantiate the BGE embedding model used for indexing.

    Args:
        model_name: HuggingFace model identifier.
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        A configured ``HuggingFaceBgeEmbeddings`` instance.
    """
    # normalize_embeddings=True so dot products equal cosine similarity.
    return HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )
def load_reranker_model(
    reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cpu"
) -> CrossEncoder:
    """Load the cross-encoder used to rerank retrieved documents.

    Args:
        reranker_model_name: HuggingFace model identifier.
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.

    Returns:
        A ``CrossEncoder`` truncating inputs at 512 tokens.
    """
    return CrossEncoder(model_name=reranker_model_name, max_length=512, device=device)
def main(
    file: str = "example_data/2401.08406.pdf",
    query: Optional[str] = None,
    llm_name="mistral",
):
    """Run the retrieval pipeline end-to-end for one PDF and one query.

    Loads the PDF, builds the parent/child retriever, reranks the hits for
    *query*, and prints the single best (document, score) pair.

    Args:
        file: Path to the PDF to index.
        query: Question to answer from the document; must be provided.
        llm_name: Kept for CLI compatibility; not used by this function.

    Raises:
        ValueError: If *query* is ``None``.
        RAGException: If no relevant document is found (via retrieve_context).
    """
    if query is None:
        # Fail fast with a clear message instead of passing None into the
        # retriever, which would only surface as an opaque downstream error.
        raise ValueError("A query must be provided, e.g. --query 'What is ...?'")
    docs = load_pdf(files=file)
    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()
    context = retrieve_context(
        query, retriever=retriever, reranker_model=reranker_model
    )[0]
    print("context:\n", context, "\n", "=" * 50, "\n")
if __name__ == "__main__":
    # jsonargparse builds a command-line interface from main()'s signature,
    # turning each parameter (file, query, llm_name) into a --flag.
    from jsonargparse import CLI
    CLI(main)
|