Upload 5 files
- black.py +131 -0
- log/output.log +0 -0
- rag_101/client.py +52 -0
- rag_101/rag.py +61 -0
- rag_101/retriever.py +160 -0
black.py
ADDED
@@ -0,0 +1,131 @@
import os

os.environ["HF_HOME"] = "weights"
os.environ["TORCH_HOME"] = "weights"

import gc
import re
import uuid
import textwrap
import subprocess
import nest_asyncio
from dotenv import load_dotenv
from IPython.display import Markdown, display

from llama_index.core import Settings
from llama_index.llms.ollama import Ollama
from llama_index.core import PromptTemplate
from llama_index.core import SimpleDirectoryReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core import VectorStoreIndex
from llama_index.core.storage.storage_context import StorageContext

from langchain.embeddings import HuggingFaceEmbeddings
from llama_index.embeddings.langchain import LangchainEmbedding

from rag_101.retriever import (
    load_embedding_model,
    load_reranker_model,
)

# allow nested access to the event loop
nest_asyncio.apply()

# set up the LLM
llm = Ollama(model="mistral", request_timeout=60.0)

# set up the embedding model
lc_embedding_model = load_embedding_model()
embed_model = LangchainEmbedding(lc_embedding_model)


# utility functions
def parse_github_url(url):
    pattern = r"https://github\.com/([^/]+)/([^/]+)"
    match = re.match(pattern, url)
    return match.groups() if match else (None, None)


def clone_github_repo(repo_url):
    try:
        print("Cloning the repo ...")
        result = subprocess.run(
            ["git", "clone", repo_url], check=True, text=True, capture_output=True
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to clone repository: {e}")
        return None


def validate_owner_repo(owner, repo):
    return bool(owner) and bool(repo)


# Set up a query engine
def setup_query_engine(github_url):

    owner, repo = parse_github_url(github_url)

    if validate_owner_repo(owner, repo):
        # Clone the GitHub repo & save it in a directory
        input_dir_path = f"{repo}"

        if os.path.exists(input_dir_path):
            pass
        else:
            clone_github_repo(github_url)

        loader = SimpleDirectoryReader(
            input_dir=input_dir_path,
            required_exts=[".py", ".ipynb", ".js", ".ts", ".md"],
            recursive=True,
        )

        try:
            docs = loader.load_data()

            # ====== Create vector store and upload data ======
            Settings.embed_model = embed_model
            index = VectorStoreIndex.from_documents(docs, show_progress=True)
            # TODO: try async index creation for faster embedding generation & persist it to memory!
            # index = VectorStoreIndex(docs, use_async=True)

            # ====== Set up a query engine ======
            Settings.llm = llm
            query_engine = index.as_query_engine(similarity_top_k=4)

            # ====== Customise the prompt template ======
            qa_prompt_tmpl_str = (
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information above, I want you to think step by step to answer the query in a crisp manner; in case you don't know the answer, say 'I don't know!'.\n"
                "Query: {query_str}\n"
                "Answer: "
            )
            qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)

            query_engine.update_prompts(
                {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
            )

            if docs:
                print("Data loaded successfully!!")
                print("Ready to chat!!")
            else:
                print("No data found, check if the repository is not empty!")

            return query_engine

        except Exception as e:
            print(f"An error occurred: {e}")
    else:
        print("Invalid GitHub repo, try again!")
        return None


# Provide the URL of the repository you want to chat with
github_url = "https://github.com/Aniket23160/Pose-Graph-SLAM"

query_engine = setup_query_engine(github_url=github_url)
print("----------------------------------------------------------------")
query = "What is this repo about?"
print(f"Question: {query}")
response = query_engine.query(query)
print(f"Answer: {response}")
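The TODO above hints at building the index asynchronously and persisting it so later runs skip re-embedding. A minimal sketch of what that could look like, assuming the installed llama_index version supports the `use_async` flag and StorageContext-based persistence, with "storage" as a hypothetical directory:

# Hedged sketch: async index creation plus on-disk persistence (API availability
# depends on the installed llama_index version).
Settings.embed_model = embed_model
index = VectorStoreIndex.from_documents(docs, use_async=True, show_progress=True)

# Persist the index so subsequent runs can reload it instead of re-embedding.
index.storage_context.persist(persist_dir="storage")  # "storage" is an illustrative path

# Reload on a later run (assumes load_index_from_storage is available):
# from llama_index.core import StorageContext, load_index_from_storage
# storage_context = StorageContext.from_defaults(persist_dir="storage")
# index = load_index_from_storage(storage_context)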
log/output.log
ADDED
File without changes
rag_101/client.py
ADDED
@@ -0,0 +1,52 @@
import time
from typing import List, Optional, Union

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from retriever import (
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)


def main(
    file: str = "example_data/2401.08406.pdf",
    llm_name="mistral",
):
    docs = load_pdf(files=file)

    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()

    llm = ChatOllama(model=llm_name)
    prompt_template = ChatPromptTemplate.from_template(
        (
            "Please answer the following question based on the provided `context` that follows the question.\n"
            "If you do not know the answer then just say 'I do not know'\n"
            "question: {question}\n"
            "context: ```{context}```\n"
        )
    )
    chain = prompt_template | llm | StrOutputParser()

    while True:
        query = input("Ask question: ")
        context = retrieve_context(
            query, retriever=retriever, reranker_model=reranker_model
        )[0]
        print("LLM Response: ", end="")
        for e in chain.stream({"context": context[0].page_content, "question": query}):
            print(e, end="")
        print()
        time.sleep(0.1)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
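Because `main` is exposed through jsonargparse's `CLI`, each keyword argument becomes a command-line flag. A minimal sketch of driving the same loop programmatically (the PDF path is only an illustrative placeholder, and it assumes the script is run from inside the rag_101/ directory, matching its own `from retriever import ...` style):

# Roughly equivalent to:
#   python client.py --file example_data/2401.08406.pdf --llm_name mistral
from client import main  # assumes the current working directory is rag_101/

main(file="example_data/2401.08406.pdf", llm_name="mistral")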
rag_101/rag.py
ADDED
@@ -0,0 +1,61 @@
from typing import Iterator

from langchain.callbacks import FileCallbackHandler
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from loguru import logger

from rag_101.retriever import (
    RAGException,
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)


class RAGClient:
    embedding_model = load_embedding_model()
    reranker_model = load_reranker_model()

    def __init__(self, files, model="mistral"):
        docs = load_pdf(files=files)
        self.retriever = create_parent_retriever(docs, self.embedding_model)

        llm = ChatOllama(model=model)
        prompt_template = ChatPromptTemplate.from_template(
            (
                "Please answer the following question based on the provided `context` that follows the question.\n"
                "Think step by step before coming to answer. If you do not know the answer then just say 'I do not know'\n"
                "question: {question}\n"
                "context: ```{context}```\n"
            )
        )
        self.chain = prompt_template | llm | StrOutputParser()

    def stream(self, query: str) -> Iterator[str]:
        try:
            context, similarity_score = self.retrieve_context(query)[0]
            context = context.page_content
            if similarity_score < 0.005:
                context = "This context is not confident. " + context
        except RAGException as e:
            context, similarity_score = e.args[0], 0
        logger.info(context)
        for r in self.chain.stream({"context": context, "question": query}):
            yield r

    def retrieve_context(self, query: str):
        return retrieve_context(
            query, retriever=self.retriever, reranker_model=self.reranker_model
        )

    def generate(self, query: str) -> dict:
        contexts = self.retrieve_context(query)

        return {
            "contexts": contexts,
            "response": self.chain.invoke(
                {"context": contexts[0][0].page_content, "question": query}
            ),
        }
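A minimal sketch of using `RAGClient` end to end (the PDF path is illustrative; it assumes a local Ollama server with the `mistral` model is running and the BGE embedding/reranker weights can be downloaded):

from rag_101.rag import RAGClient

# Build the client over one or more PDFs; this loads the embedding and reranker models.
client = RAGClient(files="example_data/2401.08406.pdf", model="mistral")

# Stream the answer token by token.
for token in client.stream("What is this paper about?"):
    print(token, end="")
print()

# Or get the reranked contexts together with a full response.
result = client.generate("What is this paper about?")
print(result["response"])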
rag_101/retriever.py
ADDED
@@ -0,0 +1,160 @@
import os

os.environ["HF_HOME"] = "weights"
os.environ["TORCH_HOME"] = "weights"

from typing import List, Optional, Union

from langchain.callbacks import FileCallbackHandler
from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.documents import Document
from loguru import logger
from rich import print
from sentence_transformers import CrossEncoder
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs

logfile = "log/output.log"
logger.add(logfile, colorize=True, enqueue=True)
handler = FileCallbackHandler(logfile)


persist_directory = None


class RAGException(Exception):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def rerank_docs(reranker_model, query, retrieved_docs):
    # Score each (query, passage) pair with the cross-encoder and sort best-first.
    query_and_docs = [(query, r.page_content) for r in retrieved_docs]
    scores = reranker_model.predict(query_and_docs)
    return sorted(list(zip(retrieved_docs, scores)), key=lambda x: x[1], reverse=True)


def load_pdf(
    files: Union[str, List[str]] = "example_data/2401.08406.pdf"
) -> List[Document]:
    if isinstance(files, str):
        loader = UnstructuredFileLoader(
            files,
            post_processors=[clean_extra_whitespace, group_broken_paragraphs],
        )
        return loader.load()

    loaders = [
        UnstructuredFileLoader(
            file,
            post_processors=[clean_extra_whitespace, group_broken_paragraphs],
        )
        for file in files
    ]
    docs = []
    for loader in loaders:
        docs.extend(
            loader.load(),
        )
    return docs


def create_parent_retriever(
    docs: List[Document], embeddings_model: HuggingFaceBgeEmbeddings
):
    # This text splitter is used to create the parent documents
    parent_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=2000,
        length_function=len,
        is_separator_regex=False,
    )

    # This text splitter is used to create the child documents
    child_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=1000,
        chunk_overlap=300,
        length_function=len,
        is_separator_regex=False,
    )
    # The vectorstore to use to index the child chunks
    vectorstore = Chroma(
        collection_name="split_documents",
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    # The storage layer for the parent documents
    store = InMemoryStore()
    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        k=10,
    )
    retriever.add_documents(docs)
    return retriever


def retrieve_context(query, retriever, reranker_model):
    retrieved_docs = retriever.get_relevant_documents(query)

    if len(retrieved_docs) == 0:
        raise RAGException(
            f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
        )
    reranked_docs = rerank_docs(
        query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
    )
    return reranked_docs


def load_embedding_model(
    model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cpu"
) -> HuggingFaceBgeEmbeddings:
    model_kwargs = {"device": device}
    encode_kwargs = {
        "normalize_embeddings": True
    }  # set True to compute cosine similarity
    embedding_model = HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embedding_model


def load_reranker_model(
    reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cpu"
) -> CrossEncoder:
    reranker_model = CrossEncoder(
        model_name=reranker_model_name, max_length=512, device=device
    )
    return reranker_model


def main(
    file: str = "example_data/2401.08406.pdf",
    query: Optional[str] = None,
    llm_name="mistral",
):
    docs = load_pdf(files=file)

    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()

    context = retrieve_context(
        query, retriever=retriever, reranker_model=reranker_model
    )[0]
    print("context:\n", context, "\n", "=" * 50, "\n")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
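A minimal sketch of the retrieve-then-rerank flow these helpers implement (the PDF path is illustrative; it assumes the BGE embedding and reranker weights download successfully on first use):

from rag_101.retriever import (
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)

# Load a PDF, build the parent-document retriever, and load the cross-encoder reranker.
docs = load_pdf(files="example_data/2401.08406.pdf")  # illustrative path
retriever = create_parent_retriever(docs, load_embedding_model())
reranker = load_reranker_model()

# retrieve_context returns (Document, score) pairs sorted by cross-encoder score.
ranked = retrieve_context(
    "What problem does the paper address?", retriever=retriever, reranker_model=reranker
)
best_doc, best_score = ranked[0]
print(best_score, best_doc.page_content[:200])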