anikettty committed
Commit df83264 · verified · 1 Parent(s): d20d4dc

Upload 5 files

Files changed (5)
  1. black.py +131 -0
  2. log/output.log +0 -0
  3. rag_101/client.py +52 -0
  4. rag_101/rag.py +61 -0
  5. rag_101/retriever.py +160 -0
black.py ADDED
@@ -0,0 +1,131 @@
+ import os
+
+ os.environ["HF_HOME"] = "weights"
+ os.environ["TORCH_HOME"] = "weights"
+
+ import gc
+ import re
+ import uuid
+ import textwrap
+ import subprocess
+ import nest_asyncio
+ from dotenv import load_dotenv
+ from IPython.display import Markdown, display
+
+ from llama_index.core import Settings
+ from llama_index.llms.ollama import Ollama
+ from llama_index.core import PromptTemplate
+ from llama_index.core import SimpleDirectoryReader
+ from llama_index.core.ingestion import IngestionPipeline
+ from llama_index.core import VectorStoreIndex
+ from llama_index.core.storage.storage_context import StorageContext
+
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from llama_index.embeddings.langchain import LangchainEmbedding
+
+ from rag_101.retriever import (
+     load_embedding_model,
+     load_reranker_model,
+ )
+
+ # allows nested access to the event loop
+ nest_asyncio.apply()
+
+ # setting up the llm
+ llm = Ollama(model="mistral", request_timeout=60.0)
+
+ # setting up the embedding model
+ lc_embedding_model = load_embedding_model()
+ embed_model = LangchainEmbedding(lc_embedding_model)
+
+ # utility functions
+ def parse_github_url(url):
+     pattern = r"https://github\.com/([^/]+)/([^/]+)"
+     match = re.match(pattern, url)
+     return match.groups() if match else (None, None)
+
+ def clone_github_repo(repo_url):
+     try:
+         print("Cloning the repo ...")
+         subprocess.run(["git", "clone", repo_url], check=True, text=True, capture_output=True)
+     except subprocess.CalledProcessError as e:
+         print(f"Failed to clone repository: {e}")
+         return None
+
+ def validate_owner_repo(owner, repo):
+     return bool(owner) and bool(repo)
+
+ # Set up a query engine
+ def setup_query_engine(github_url):
+     owner, repo = parse_github_url(github_url)
+
+     if validate_owner_repo(owner, repo):
+         # Clone the GitHub repo & save it in a directory
+         input_dir_path = f"{repo}"
+
+         if not os.path.exists(input_dir_path):
+             clone_github_repo(github_url)
+
+         loader = SimpleDirectoryReader(
+             input_dir=input_dir_path,
+             required_exts=[".py", ".ipynb", ".js", ".ts", ".md"],
+             recursive=True,
+         )
+
+         try:
+             docs = loader.load_data()
+
+             # ====== Create vector store and upload data ======
+             Settings.embed_model = embed_model
+             index = VectorStoreIndex.from_documents(docs, show_progress=True)
+             # TODO: try async index creation for faster embedding generation & persist it to memory!
+             # index = VectorStoreIndex(docs, use_async=True)
+
+             # ====== Set up a query engine ======
+             Settings.llm = llm
+             query_engine = index.as_query_engine(similarity_top_k=4)
+
+             # ====== Customise the prompt template ======
+             qa_prompt_tmpl_str = (
+                 "Context information is below.\n"
+                 "---------------------\n"
+                 "{context_str}\n"
+                 "---------------------\n"
+                 "Given the context information above, think step by step to answer the query in a crisp manner; in case you don't know the answer, say 'I don't know!'.\n"
+                 "Query: {query_str}\n"
+                 "Answer: "
+             )
+             qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)
+
+             query_engine.update_prompts(
+                 {"response_synthesizer:text_qa_template": qa_prompt_tmpl}
+             )
+
+             if docs:
+                 print("Data loaded successfully!!")
+                 print("Ready to chat!!")
+             else:
+                 print("No data found; check whether the repository is empty!")
+
+             return query_engine
+
+         except Exception as e:
+             print(f"An error occurred: {e}")
+     else:
+         print("Invalid GitHub repo, try again!")
+         return None
+
+ # Provide the URL of the repository you want to chat with
+ github_url = "https://github.com/Aniket23160/Pose-Graph-SLAM"
+
+ query_engine = setup_query_engine(github_url=github_url)
+ if query_engine is None:
+     # setup failed (invalid URL or load error); bail out before querying
+     raise SystemExit("Query engine setup failed.")
+ print("----------------------------------------------------------------")
+ query = "What is this repo about?"
+ print(f"Question: {query}")
+ response = query_engine.query(query)
+ print(f"Answer: {response}")
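Note: as a sketch of the TODO above (not part of this commit), async embedding plus on-disk persistence could look like the following, assuming llama_index's use_async flag and the standard StorageContext persistence API; the "storage" directory name is illustrative.

    # Hypothetical sketch: build the index with async embedding calls and
    # persist it so later runs can skip re-embedding.
    index = VectorStoreIndex.from_documents(docs, use_async=True, show_progress=True)
    index.storage_context.persist(persist_dir="storage")

    # A later run could reload instead of rebuilding (same embed model assumed):
    # from llama_index.core import load_index_from_storage
    # storage_context = StorageContext.from_defaults(persist_dir="storage")
    # index = load_index_from_storage(storage_context)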
log/output.log ADDED
File without changes
rag_101/client.py ADDED
@@ -0,0 +1,52 @@
+ import time
+ from typing import List, Optional, Union
+
+ from langchain_community.chat_models import ChatOllama
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from retriever import (
+     create_parent_retriever,
+     load_embedding_model,
+     load_pdf,
+     load_reranker_model,
+     retrieve_context,
+ )
+
+
+ def main(
+     file: str = "example_data/2401.08406.pdf",
+     llm_name="mistral",
+ ):
+     docs = load_pdf(files=file)
+
+     embedding_model = load_embedding_model()
+     retriever = create_parent_retriever(docs, embedding_model)
+     reranker_model = load_reranker_model()
+
+     llm = ChatOllama(model=llm_name)
+     prompt_template = ChatPromptTemplate.from_template(
+         (
+             "Please answer the following question based on the provided `context` that follows the question.\n"
+             "If you do not know the answer then just say 'I do not know'\n"
+             "question: {question}\n"
+             "context: ```{context}```\n"
+         )
+     )
+     chain = prompt_template | llm | StrOutputParser()
+
+     while True:
+         query = input("Ask question: ")
+         # retrieve_context returns (document, score) pairs sorted by relevance;
+         # keep only the best-scoring pair
+         context = retrieve_context(
+             query, retriever=retriever, reranker_model=reranker_model
+         )[0]
+         print("LLM Response: ", end="")
+         for e in chain.stream({"context": context[0].page_content, "question": query}):
+             print(e, end="")
+         print()
+         time.sleep(0.1)
+
+
+ if __name__ == "__main__":
+     from jsonargparse import CLI
+
+     CLI(main)
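Note: since jsonargparse.CLI(main) exposes main's parameters as command-line flags, a typical invocation is `python rag_101/client.py --file example_data/2401.08406.pdf --llm_name mistral`, which then prompts for questions in a loop.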
rag_101/rag.py ADDED
@@ -0,0 +1,61 @@
+ from typing import Iterator
+
+ from langchain.callbacks import FileCallbackHandler
+ from langchain_community.chat_models import ChatOllama
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from loguru import logger
+
+ from rag_101.retriever import (
+     RAGException,
+     create_parent_retriever,
+     load_embedding_model,
+     load_pdf,
+     load_reranker_model,
+     retrieve_context,
+ )
+
+
+ class RAGClient:
+     # models are loaded once at class-definition time and shared across instances
+     embedding_model = load_embedding_model()
+     reranker_model = load_reranker_model()
+
+     def __init__(self, files, model="mistral"):
+         docs = load_pdf(files=files)
+         self.retriever = create_parent_retriever(docs, self.embedding_model)
+
+         llm = ChatOllama(model=model)
+         prompt_template = ChatPromptTemplate.from_template(
+             (
+                 "Please answer the following question based on the provided `context` that follows the question.\n"
+                 "Think step by step before coming to an answer. If you do not know the answer then just say 'I do not know'\n"
+                 "question: {question}\n"
+                 "context: ```{context}```\n"
+             )
+         )
+         self.chain = prompt_template | llm | StrOutputParser()
+
+     def stream(self, query: str) -> Iterator[str]:
+         try:
+             context, similarity_score = self.retrieve_context(query)[0]
+             context = context.page_content
+             if similarity_score < 0.005:
+                 context = "This context may not be reliable. " + context
+         except RAGException as e:
+             context, similarity_score = e.args[0], 0
+         logger.info(context)
+         for r in self.chain.stream({"context": context, "question": query}):
+             yield r
+
+     def retrieve_context(self, query: str):
+         return retrieve_context(
+             query, retriever=self.retriever, reranker_model=self.reranker_model
+         )
+
+     def generate(self, query: str) -> dict:
+         contexts = self.retrieve_context(query)
+
+         return {
+             "contexts": contexts,
+             "response": self.chain.invoke(
+                 {"context": contexts[0][0].page_content, "question": query}
+             ),
+         }
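Note: for reference, a minimal usage sketch of RAGClient (not part of this commit); it assumes an Ollama server with the mistral model running locally and the example PDF from retriever.py present on disk.

    # Hypothetical usage sketch: stream an answer for a single question.
    from rag_101.rag import RAGClient

    client = RAGClient(files="example_data/2401.08406.pdf")
    for token in client.stream("What is the paper about?"):
        print(token, end="", flush=True)
    print()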
rag_101/retriever.py ADDED
@@ -0,0 +1,160 @@
+ import os
+
+ os.environ["HF_HOME"] = "weights"
+ os.environ["TORCH_HOME"] = "weights"
+
+ from typing import List, Optional, Union
+
+ from langchain.callbacks import FileCallbackHandler
+ from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
+ from langchain.retrievers.document_compressors import EmbeddingsFilter
+ from langchain.storage import InMemoryStore
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import UnstructuredFileLoader
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import FAISS, Chroma
+ from langchain_core.documents import Document
+ from loguru import logger
+ from rich import print
+ from sentence_transformers import CrossEncoder
+ from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
+
+ logfile = "log/output.log"
+ logger.add(logfile, colorize=True, enqueue=True)
+ handler = FileCallbackHandler(logfile)
+
+
+ persist_directory = None
+
+
+ class RAGException(Exception):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+
+ def rerank_docs(reranker_model, query, retrieved_docs):
+     # score each (query, passage) pair with the cross-encoder and sort
+     # the documents from most to least relevant
+     query_and_docs = [(query, r.page_content) for r in retrieved_docs]
+     scores = reranker_model.predict(query_and_docs)
+     return sorted(list(zip(retrieved_docs, scores)), key=lambda x: x[1], reverse=True)
+
+
+ def load_pdf(
+     files: Union[str, List[str]] = "example_data/2401.08406.pdf"
+ ) -> List[Document]:
+     if isinstance(files, str):
+         loader = UnstructuredFileLoader(
+             files,
+             post_processors=[clean_extra_whitespace, group_broken_paragraphs],
+         )
+         return loader.load()
+
+     loaders = [
+         UnstructuredFileLoader(
+             file,
+             post_processors=[clean_extra_whitespace, group_broken_paragraphs],
+         )
+         for file in files
+     ]
+     docs = []
+     for loader in loaders:
+         docs.extend(loader.load())
+     return docs
+
+
+ def create_parent_retriever(
+     docs: List[Document], embeddings_model: HuggingFaceBgeEmbeddings
+ ):
+     # This text splitter is used to create the parent documents
+     parent_splitter = RecursiveCharacterTextSplitter(
+         separators=["\n\n\n", "\n\n"],
+         chunk_size=2000,
+         length_function=len,
+         is_separator_regex=False,
+     )
+
+     # This text splitter is used to create the child documents
+     child_splitter = RecursiveCharacterTextSplitter(
+         separators=["\n\n\n", "\n\n"],
+         chunk_size=1000,
+         chunk_overlap=300,
+         length_function=len,
+         is_separator_regex=False,
+     )
+     # The vectorstore to use to index the child chunks
+     vectorstore = Chroma(
+         collection_name="split_documents",
+         embedding_function=embeddings_model,
+         persist_directory=persist_directory,
+     )
+     # The storage layer for the parent documents
+     store = InMemoryStore()
+     retriever = ParentDocumentRetriever(
+         vectorstore=vectorstore,
+         docstore=store,
+         child_splitter=child_splitter,
+         parent_splitter=parent_splitter,
+         # pass k via search_kwargs so the vector store actually returns 10 children;
+         # a bare k=10 is not a recognized ParentDocumentRetriever field
+         search_kwargs={"k": 10},
+     )
+     retriever.add_documents(docs)
+     return retriever
+
+
+ def retrieve_context(query, retriever, reranker_model):
+     retrieved_docs = retriever.get_relevant_documents(query)
+
+     if len(retrieved_docs) == 0:
+         raise RAGException(
+             f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
+         )
+     reranked_docs = rerank_docs(
+         query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
+     )
+     return reranked_docs
+
+
+ def load_embedding_model(
+     model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cpu"
+ ) -> HuggingFaceBgeEmbeddings:
+     model_kwargs = {"device": device}
+     encode_kwargs = {
+         "normalize_embeddings": True  # set True to compute cosine similarity
+     }
+     embedding_model = HuggingFaceBgeEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs,
+     )
+     return embedding_model
+
+
+ def load_reranker_model(
+     reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cpu"
+ ) -> CrossEncoder:
+     reranker_model = CrossEncoder(
+         model_name=reranker_model_name, max_length=512, device=device
+     )
+     return reranker_model
+
+
+ def main(
+     file: str = "example_data/2401.08406.pdf",
+     query: Optional[str] = None,
+     llm_name="mistral",
+ ):
+     docs = load_pdf(files=file)
+
+     embedding_model = load_embedding_model()
+     retriever = create_parent_retriever(docs, embedding_model)
+     reranker_model = load_reranker_model()
+
+     if query is None:
+         # fall back to a generic query so the demo runs without arguments
+         query = "What is this document about?"
+
+     context = retrieve_context(
+         query, retriever=retriever, reranker_model=reranker_model
+     )[0]
+     print("context:\n", context, "\n", "=" * 50, "\n")
+
+
+ if __name__ == "__main__":
+     from jsonargparse import CLI
+
+     CLI(main)
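Note: a small sketch (not in the commit) of the shape retrieve_context returns: (Document, score) pairs sorted by cross-encoder score, descending. It continues from the objects built in main() above, so retriever and reranker_model are assumed to be in scope.

    # Hypothetical sketch: inspect the reranked results.
    docs_and_scores = retrieve_context(
        "What is this document about?", retriever=retriever, reranker_model=reranker_model
    )
    best_doc, best_score = docs_and_scores[0]  # highest score first
    print(f"top score: {best_score:.4f}")
    print(best_doc.page_content[:200])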