Upload 15 files

- Dockerfile +11 -0
- app.py +247 -0
- chainlit.md +14 -0
- myutils/__pycache__/finetuning.cpython-311.pyc +0 -0
- myutils/__pycache__/finetuning.cpython-312.pyc +0 -0
- myutils/__pycache__/prepare_data_for_finetuning.cpython-311.pyc +0 -0
- myutils/__pycache__/rag_pipeline_utils.cpython-311.pyc +0 -0
- myutils/__pycache__/rag_pipeline_utils.cpython-312.pyc +0 -0
- myutils/__pycache__/ragas_pipeline.cpython-311.pyc +0 -0
- myutils/__pycache__/ragas_pipeline.cpython-312.pyc +0 -0
- myutils/finetuning.py +410 -0
- myutils/pdfloader.py +87 -0
- myutils/rag_pipeline_utils.py +289 -0
- myutils/ragas_pipeline.py +86 -0
- requirements.txt +17 -0
Dockerfile
ADDED
FROM python:3.11
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR $HOME/app
COPY --chown=user . $HOME/app
# COPY does not expand "~", so the destination must spell out $HOME
COPY ./requirements.txt $HOME/app/requirements.txt
RUN pip install -r requirements.txt
COPY . .
CMD ["chainlit", "run", "app.py", "--port", "7860"]
app.py
ADDED
"""
app_end_to_end_prototype.py

1. This app loads two pdf documents and allows the user to ask questions about these documents.
   The documents used are:

   https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf
   AND
   https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf

2. The two documents are pre-processed on start. Brief details on the pre-processing:
   a. Text is split into chunks using the langchain RecursiveCharacterTextSplitter method.
   b. The text in each chunk is converted to an embedding using OpenAI text-embedding-3-small embeddings.
      Each embedding produced by this model has dimension 1536,
      so each chunk is represented by an embedding of dimension 1536.
   c. The collection of embeddings for all chunks, along with metadata, is saved/indexed in a vector database.
   d. For this exercise, I use an in-memory instance of the Qdrant vector db.

3. The next step is to build a RAG pipeline to answer questions. This is implemented as follows:
   a. I use a simple prompt that instructs the LLM to answer based only on retrieved contexts.
   b. First, the user query is encoded using the same embedding model as the documents.
   c. Second, a set of relevant documents is returned by the retriever,
      which efficiently searches the vector db and returns the most relevant chunks.
   d. Third, the user query and retrieved contexts are then passed to a chat-enabled LLM.
      I use OpenAI's gpt-4o-mini throughout this exercise.
   e. Fourth, the chat model processes the user query and context along with the prompt and
      generates a response that is then passed to the user.

4. The cl.on_chat_start decorator initiates the conversation with the user.

5. The cl.on_message decorator wraps the main function, which does the following:
   a. receives the query that the user types in
   b. runs the RAG pipeline
   c. sends results back to the UI for display

Additional Notes:
   a. note the use of async functions and async/await syntax throughout the module
   b. note the use of streaming when sending the response back to the UI
"""

import os
from typing import List
from dotenv import load_dotenv

# chainlit imports
import chainlit as cl

# langchain imports
# document loader
from langchain_community.document_loaders import PyMuPDFLoader
# text splitter
from langchain_text_splitters import RecursiveCharacterTextSplitter
# embeddings model to embed each chunk of text in doc
from langchain_openai import OpenAIEmbeddings
# llm for text generation using prompt plus retrieved context plus query
from langchain_openai import ChatOpenAI
# templates to create custom prompts
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
# LCEL Runnable Passthrough
from langchain_core.runnables import RunnablePassthrough
# to parse output from llm
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

from myutils.rag_pipeline_utils import SimpleTextSplitter, SemanticTextSplitter, VectorStore, AdvancedRetriever
from myutils.ragas_pipeline import RagasPipeline
from myutils.rag_pipeline_utils import load_all_pdfs, set_up_rag_pipeline


load_dotenv()

# Flag to indicate if pdfs should be loaded directly from URLs.
# If True, get pdfs from urls; if False, get them from a local copy.
LOAD_PDF_DIRECTLY_FROM_URL = True

# set the APP_MODE - one of two choices:
# early_prototype means use OpenAI embeddings
# advanced_prototype means use finetuned model embeddings
APP_MODE = "early_prototype"

if APP_MODE == "early_prototype":
    embeddings = OpenAIEmbeddings(model='text-embedding-3-small')
    embed_dim = 1536
    appendix_to_user_message = "This chatbot is built using OpenAI Embeddings as a fast prototype."
else:
    finetuned_model_id = "Vira21/finetuned_arctic"
    embeddings = HuggingFaceEmbeddings(model_name=finetuned_model_id)
    appendix_to_user_message = "Our Tech team finetuned snowflake-arctic-embed-m to bring you this chatbot!!"
    embed_dim = 768

rag_template = """
You are an assistant for question-answering tasks.
You will be given documents on the risks of AI, and on frameworks and
policies formulated by various governmental agencies to articulate
these risks and to safeguard against them.

Use the following pieces of retrieved context to answer
the question.

You must answer the question based only on the context provided.

If you don't know the answer, or if the context does not provide sufficient information,
then say that you don't know.

If the user expresses gratitude or types a greeting, respond respectfully instead of saying
"I don't know." Acknowledge their message and kindly ask if they have any questions related to AI risks.

Think through your answer step-by-step.

Context:
{context}

Question:
{question}
"""

rag_prompt = ChatPromptTemplate.from_template(template=rag_template)

# parameters to manage text splitting/chunking
chunk_kwargs = {
    'chunk_size': 1000,
    'chunk_overlap': 300
}

retrieval_chain_kwargs = {
    'location': ":memory:",
    'collection_name': 'End_to_End_Prototype',
    'embeddings': embeddings,
    'embed_dim': embed_dim,
    'prompt': rag_prompt,
    'qa_llm': ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
}

urls_for_pdfs = [
    "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
    "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
]

pdf_file_paths = [
    './data/docs_for_rag/Blueprint-for-an-AI-Bill-of-Rights.pdf',
    './data/docs_for_rag/NIST.AI.600-1.pdf'
]

# if flag is True, then pass in pointers to URLs
# if flag is False, then pass in file pointers
if LOAD_PDF_DIRECTLY_FROM_URL:
    docpathlist = urls_for_pdfs
else:
    docpathlist = pdf_file_paths


class RetrievalAugmentedQAPipelineWithLangchain:
    def __init__(self,
                 list_of_documents,
                 chunk_kwargs,
                 retrieval_chain_kwargs):
        self.list_of_documents = list_of_documents
        self.chunk_kwargs = chunk_kwargs
        self.retrieval_chain_kwargs = retrieval_chain_kwargs

        self.load_documents()
        self.split_text()
        self.set_up_rag_pipeline()
        return

    def load_documents(self):
        self.documents = load_all_pdfs(self.list_of_documents)
        return self

    def split_text(self):
        baseline_text_splitter = \
            SimpleTextSplitter(**self.chunk_kwargs, documents=self.documents)
        # split text for baseline case
        self.baseline_text_splits = baseline_text_splitter.split_text()
        return self

    def set_up_rag_pipeline(self):
        self.retrieval_chain = set_up_rag_pipeline(
            **self.retrieval_chain_kwargs,
            text_splits=self.baseline_text_splits
        )
        return self


RETRIEVAL_CHAIN = \
    RetrievalAugmentedQAPipelineWithLangchain(
        list_of_documents=docpathlist,
        chunk_kwargs=chunk_kwargs,
        retrieval_chain_kwargs=retrieval_chain_kwargs
    ).retrieval_chain


@cl.set_starters
async def set_starters():
    return [
        cl.Starter(
            label="AI Bill of Rights",
            message="What are the key principles outlined in the Blueprint for an AI Bill of Rights?",
            description="Learn about the fundamental rights and protections proposed in the AI Bill of Rights",
        ),
        cl.Starter(
            label="AI Risk Assessment",
            message="What are the main risks and challenges identified in the NIST AI Risk Management Framework?",
            description="Understand key AI risks and mitigation strategies",
        ),
        cl.Starter(
            label="Data Privacy Protection",
            message="How do these documents address data privacy and protection in AI systems?",
            description="Explore guidelines for protecting personal data in AI applications",
        ),
        cl.Starter(
            label="AI System Testing",
            message="What are the recommended approaches for testing and validating AI systems for safety and reliability?",
            description="Learn about AI system validation and testing requirements",
        ),
    ]


@cl.on_chat_start
async def on_chat_start():
    # initialize the retrieval chain without sending a welcome message
    cl.user_session.set("retrieval_chain", RETRIEVAL_CHAIN)


@cl.on_message
async def main(message):
    retrieval_chain = cl.user_session.get("retrieval_chain")

    msg = cl.Message(content="")

    # run the (synchronous) chain in a worker thread so the event loop is not blocked
    result = await cl.make_async(retrieval_chain.invoke)({"question": message.content})

    # the chain returns a complete response; stream it to the UI
    # character-by-character to simulate token streaming
    for stream_resp in result["response"].content:
        await msg.stream_token(stream_resp)

    await msg.send()
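Since `RETRIEVAL_CHAIN` is built at module import time, the pipeline can also be smoke-tested outside the Chainlit UI. A minimal sketch, assuming `OPENAI_API_KEY` is set in the environment and the two PDF URLs are reachable; the module name `app` and the question are illustrative:

# hypothetical smoke test for the RAG pipeline, run outside Chainlit
from app import RETRIEVAL_CHAIN

result = RETRIEVAL_CHAIN.invoke({"question": "What principles does the AI Bill of Rights outline?"})
print(result["response"].content)   # generated answer
print(len(result["context"]))       # retrieved chunks (k=5 in the simple retriever)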
chainlit.md
ADDED
# Welcome to Chainlit! 🚀🤖

Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.

## Useful Links 🔗

- **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
- **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬

We can't wait to see what you create with Chainlit! Happy coding! 💻😊

## Welcome screen

To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.
myutils/__pycache__/finetuning.cpython-311.pyc
ADDED
Binary file (21 kB)

myutils/__pycache__/finetuning.cpython-312.pyc
ADDED
Binary file (19.2 kB)

myutils/__pycache__/prepare_data_for_finetuning.cpython-311.pyc
ADDED
Binary file (14.7 kB)

myutils/__pycache__/rag_pipeline_utils.cpython-311.pyc
ADDED
Binary file (12.6 kB)

myutils/__pycache__/rag_pipeline_utils.cpython-312.pyc
ADDED
Binary file (11.1 kB)

myutils/__pycache__/ragas_pipeline.cpython-311.pyc
ADDED
Binary file (3.91 kB)

myutils/__pycache__/ragas_pipeline.cpython-312.pyc
ADDED
Binary file (3.35 kB)
myutils/finetuning.py
ADDED
"""
finetuning_pipeline.py

Collects a number of methods in classes to streamline the finetuning of model embeddings.

#### Fine-tuning Steps

1. Prepare Train, Val and Test Data
   - if needed, chunk data to get a list of LC Documents
   - split the list into train, val and test sub-groups
   - for each sub-group, use an LLM to generate a list of POSITIVE question, context pairs.
     This is done by passing the context to the LLM along with a prompt to generate
     `n_questions` number of questions; the questions are extracted from the LLM output
     and paired with the underlying context. Note that each context will have more than
     one question paired with it.
   - write out the list of question, context pairs for train, val and test sub-groups
     into a jsonl file for future reference
   - the train sub-group is loaded into a HF Dataset object for use in training
2. Data Loader
   - set up the data loader
   - this includes the training data along with batch size information
3. Load model to be finetuned
   - use the HF model name to load the model
4. Set up the loss function
   - concept of inner loss: MultipleNegativesRankingLoss
   - wrap the inner loss in an overall loss: MatryoshkaLoss
5. Set up the finetuning pipeline
   - this includes data, model, loss and hyperparameters
   - hyperparameters include number of epochs, warmup, etc.
6. Run the finetuning pipeline and get modified model embeddings
   - save these embeddings
   - see if these can be loaded onto HF
   - see if these can be downloaded from HF
7. Validation Loss
   - run assessment on the val sub-group
"""

# imports
import uuid
import random
import tqdm
import re
import json
from typing import List

import pandas as pd

from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer
from sentence_transformers import InputExample
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.evaluation import InformationRetrievalEvaluator

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


class GenerateQuestionsForContexts:
    def __init__(self,
                 qa_chat_model_name="gpt-4o-mini",
                 n_questions=3):
        self.qa_chat_model_name = qa_chat_model_name
        # regex pattern used to extract questions from the LLM response:
        # first group is the question number - an integer - followed by a period;
        # second group is everything that follows
        self.regex_pattern = r'^(\d+)\.(.+)'
        self.n_questions = n_questions

        self.set_up_chat_model()
        self.set_up_question_generation_chain()
        return

    def get_unique_id(self, id_set):
        """
        Generate a unique id not present in the input set of ids.
        Input
            a set of unique identifiers
        Returns
            a new unique id not in the input set
            the updated input set of ids, including the newly generated id
        """
        id = str(uuid.uuid4())
        while id in id_set:
            id = str(uuid.uuid4())
        id_set.add(id)
        return id, id_set

    def set_up_chat_model(self):
        self.qa_chat_model = ChatOpenAI(
            model=self.qa_chat_model_name,
            temperature=0
        )
        return self

    def set_up_question_generation_chain(self):
        qa_prompt = """\
Given the following context, you must generate questions based on only the provided context.

You are to generate {n_questions} questions which should be provided in the following format:

1. QUESTION #1
2. QUESTION #2
...

Context:
{context}
"""
        qa_prompt_template = ChatPromptTemplate.from_template(qa_prompt)
        self.question_generation_chain = qa_prompt_template | self.qa_chat_model
        return self

    def create_questions(self, documents, n_questions):
        questions = {}
        relevant_docs = {}

        q_id_set = set()
        for document in tqdm.tqdm(documents):  # note tqdm.tqdm (NOT just tqdm as in the original notebook)
            this_question_set = \
                self.question_generation_chain.invoke(
                    {
                        'context': document.page_content,
                        'n_questions': n_questions
                    }
                )
            for question in this_question_set.content.split("\n"):
                if len(question) > 0:
                    try:
                        q_id, q_id_set = self.get_unique_id(q_id_set)
                        matched_pattern = re.search(self.regex_pattern, question)  # regex search for "n. <question>"
                        if len(matched_pattern.group(2)) > 0:
                            questions[q_id] = matched_pattern.group(2).strip()  # extract the question string
                            relevant_docs[q_id] = [document.metadata["id"]]
                    except Exception:
                        continue
        return questions, relevant_docs


class PrepareDataForFinetuning(GenerateQuestionsForContexts):
    def __init__(self,
                 chunk_size=None, chunk_overlap=None, len_function=None,
                 lcdocuments=None, run_optional_text_splitter=False,
                 all_splits=None, train_val_test_size=[10, 5, 5],
                 train_val_test_split_type='random',
                 random_seed=69, qa_chat_model_name="gpt-4o-mini",
                 n_questions=2, batch_size=5):

        super().__init__(qa_chat_model_name=qa_chat_model_name,
                         n_questions=n_questions)

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.len_function = len_function

        self.lcdocuments = lcdocuments
        self.run_optional_text_splitter = run_optional_text_splitter

        self.all_doc_splits = all_splits

        self.train_val_test_size = train_val_test_size
        self.n_train = self.train_val_test_size[0]
        self.n_val = self.train_val_test_size[1]
        self.n_test = self.train_val_test_size[2]
        self.train_val_test_split_type = train_val_test_split_type

        self.random_seed = random_seed
        self.batch_size = batch_size
        return

    def optional_text_splitter(self):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=self.len_function
        )
        self.all_doc_splits = text_splitter.split_documents(self.lcdocuments.load())
        return self

    def attach_unique_ids_to_docs(self):
        id_set = set()
        for docsplit in self.all_doc_splits:
            id, id_set = self.get_unique_id(id_set)
            docsplit.metadata["id"] = id
        return self

    def simple_train_val_test_splits(self):
        self.training_splits = self.all_doc_splits[:self.n_train]
        self.val_splits = self.all_doc_splits[self.n_train:self.n_train+self.n_val]
        self.test_splits = self.all_doc_splits[self.n_train+self.n_val:]
        return self

    def randomized_train_val_test_splits(self):
        # set the same seed to be able to replicate the result of
        # the random shuffle below
        random.seed(self.random_seed)

        # randomly order the elements of the list of documents
        randomly_ordered_documents = self.all_doc_splits.copy()
        random.shuffle(randomly_ordered_documents)

        # assign slices to training, val and test
        self.training_splits = randomly_ordered_documents[:self.n_train]
        self.val_splits = randomly_ordered_documents[self.n_train: self.n_train+self.n_val]
        self.test_splits = randomly_ordered_documents[self.n_train+self.n_val:]
        return self

    def get_all_questions(self):
        self.training_questions, self.training_relevant_contexts = \
            self.create_questions(documents=self.training_splits, n_questions=self.n_questions)
        self.val_questions, self.val_relevant_contexts = \
            self.create_questions(documents=self.val_splits, n_questions=self.n_questions)
        self.test_questions, self.test_relevant_contexts = \
            self.create_questions(documents=self.test_splits, n_questions=self.n_questions)
        return self

    def save_dataset_to_jsonl(self, splits, questions, relevant_contexts, jsonl_filename):
        """
        NOTE: Each `jsonl` file has a single line! This is a nested JSON structure.
        Primary keys for each file are `questions`, `relevant_contexts` and `corpus`.
        1. Each `questions` element is a json object whose key is a question id and
           whose value is the string corresponding to the question.
        2. Each `relevant_contexts` element is a json object whose key is a question id
           and whose value is the unique id of the corresponding context.
        3. Each `corpus` element is a json object whose key is a unique context id
           and whose value is the context string.
        """
        corpus = {item.metadata["id"]: item.page_content for item in splits}
        dataset_dict = {
            "questions": questions,
            "relevant_contexts": relevant_contexts,
            "corpus": corpus
        }
        with open(jsonl_filename, "w") as f:
            json.dump(dataset_dict, f)
        return dataset_dict

    def save_train_val_test_dataset_to_jsonl(self):
        self.train_dataset = \
            self.save_dataset_to_jsonl(self.training_splits,
                                       self.training_questions,
                                       self.training_relevant_contexts,
                                       jsonl_filename='./data/finetuning_data/training_dataset.jsonl')

        self.val_dataset = \
            self.save_dataset_to_jsonl(self.val_splits,
                                       self.val_questions,
                                       self.val_relevant_contexts,
                                       jsonl_filename='./data/finetuning_data/val_dataset.jsonl')

        self.test_dataset = \
            self.save_dataset_to_jsonl(self.test_splits,
                                       self.test_questions,
                                       self.test_relevant_contexts,
                                       jsonl_filename='./data/finetuning_data/test_dataset.jsonl')
        return self

    def run_all_prep_data(self):
        # if docs are passed in pre-chunking, then split docs
        if self.run_optional_text_splitter is True:
            self.optional_text_splitter()

        # each chunk i.e., context gets a unique id
        self.attach_unique_ids_to_docs()

        # split into train, val and test - either random or simple slicing
        if self.train_val_test_split_type.upper() == 'RANDOM':
            self.randomized_train_val_test_splits()
        else:
            self.simple_train_val_test_splits()

        # generate questions for each context
        # this step involves a large number of LLM calls
        self.get_all_questions()

        # save train, val and test datasets in jsonl format
        self.save_train_val_test_dataset_to_jsonl()
        return self


class FineTuneModel:
    def __init__(self,
                 train_data,
                 val_data,
                 batch_size,
                 base_model_id='Snowflake/snowflake-arctic-embed-m',
                 matryoshka_dimensions=[768, 512, 256, 128, 64],
                 number_of_training_epochs=5,
                 finetuned_model_output_path='finetuned_arctic',
                 evaluation_steps=50):
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size

        self.base_model_id = base_model_id
        self.matryoshka_dimensions = matryoshka_dimensions
        self.number_of_training_epochs = number_of_training_epochs
        self.finetuned_model_output_path = finetuned_model_output_path
        self.evaluation_steps = evaluation_steps

        self.model = SentenceTransformer(self.base_model_id)
        return

    def prepare_data_for_finetuning(self, data):
        corpus = data['corpus']
        queries = data['questions']
        relevant_docs = data['relevant_contexts']
        return corpus, queries, relevant_docs

    def get_data_loader(self):
        corpus, queries, relevant_docs = self.prepare_data_for_finetuning(self.train_data)

        examples = []
        for query_id, query in queries.items():
            doc_id = relevant_docs[query_id][0]
            text = corpus[doc_id]
            example = InputExample(texts=[query, text])
            examples.append(example)
        self.loader = DataLoader(examples, batch_size=self.batch_size)
        return self

    def loss_function(self):
        inner_training_loss = MultipleNegativesRankingLoss(self.model)
        self.train_loss = MatryoshkaLoss(
            self.model,
            inner_training_loss,
            matryoshka_dims=self.matryoshka_dimensions
        )
        return self

    def get_evaluator_for_val(self):
        corpus, queries, relevant_docs = self.prepare_data_for_finetuning(self.val_data)
        self.evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)
        return self

    def fit_model(self):
        # warm up over the first 10% of training steps
        warmup_steps = int(len(self.loader) * self.number_of_training_epochs * 0.1)
        self.model.fit(
            train_objectives=[(self.loader, self.train_loss)],
            epochs=self.number_of_training_epochs,
            warmup_steps=warmup_steps,
            output_path=self.finetuned_model_output_path,
            show_progress_bar=True,
            evaluator=self.evaluator,
            evaluation_steps=self.evaluation_steps,
        )
        return self

    def run_steps_to_finetune_model(self):
        # load train data into the DataLoader
        self.get_data_loader()

        # set up the loss function
        self.loss_function()

        # set up the evaluator with val data
        self.get_evaluator_for_val()

        # finetune the model
        self.fit_model()
        return self


class FineTuneModelAndEvaluateRetriever(FineTuneModel):
    def __init__(self,
                 train_data,
                 val_data,
                 test_data,
                 batch_size,
                 base_model_id='Snowflake/snowflake-arctic-embed-m',
                 matryoshka_dimensions=[768, 512, 256, 128, 64],
                 number_of_training_epochs=5,
                 finetuned_model_output_path='finetuned_arctic',
                 evaluation_steps=50):
        super().__init__(train_data=train_data,
                         val_data=val_data,
                         batch_size=batch_size,
                         base_model_id=base_model_id,
                         matryoshka_dimensions=matryoshka_dimensions,
                         number_of_training_epochs=number_of_training_epochs,
                         finetuned_model_output_path=finetuned_model_output_path,
                         evaluation_steps=evaluation_steps)
        self.test_data = test_data
        return

    def set_up_test_data_for_retrieval(self, embedding_model_for_retrieval, top_k_for_retrieval):
        corpus, questions, relevant_docs = self.prepare_data_for_finetuning(self.test_data)

        documents = [Document(page_content=content, metadata={"id": doc_id})
                     for doc_id, content in corpus.items()]

        vectorstore = FAISS.from_documents(documents, embedding_model_for_retrieval)
        retriever = vectorstore.as_retriever(search_kwargs={"k": top_k_for_retrieval})
        return corpus, questions, relevant_docs, retriever

    def evaluate_embeddings_model(self, embedding_model_for_retrieval, top_k_for_retrieval, verbose=False):
        corpus, questions, relevant_docs, retriever = \
            self.set_up_test_data_for_retrieval(embedding_model_for_retrieval, top_k_for_retrieval)
        eval_results = []
        for id, question in tqdm.tqdm(questions.items()):
            retrieved_nodes = retriever.invoke(question)
            retrieved_ids = [node.metadata["id"] for node in retrieved_nodes]
            expected_id = relevant_docs[id][0]
            is_hit = expected_id in retrieved_ids
            eval_results.append({"id": id, "question": question, "expected_id": expected_id, "is_hit": is_hit})
        return eval_results
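Taken together, these classes form a small end-to-end workflow: prepare question/context pairs, finetune the embedding model, then check retrieval hit rate on the held-out test split. A minimal sketch of the wiring, assuming `all_splits` already holds chunked LangChain Documents, `OPENAI_API_KEY` is set, and `./data/finetuning_data/` exists:

# hypothetical driver for the finetuning workflow sketched above
from langchain_huggingface import HuggingFaceEmbeddings
from myutils.finetuning import PrepareDataForFinetuning, FineTuneModelAndEvaluateRetriever

prep = PrepareDataForFinetuning(all_splits=all_splits,   # assumed: pre-chunked LC Documents
                                train_val_test_size=[10, 5, 5],
                                n_questions=2, batch_size=5)
prep.run_all_prep_data()      # many LLM calls; writes train/val/test jsonl files

tuner = FineTuneModelAndEvaluateRetriever(train_data=prep.train_dataset,
                                          val_data=prep.val_dataset,
                                          test_data=prep.test_dataset,
                                          batch_size=5)
tuner.run_steps_to_finetune_model()   # saves the tuned model to 'finetuned_arctic'

# hit-rate check of the tuned embeddings on the test split
tuned = HuggingFaceEmbeddings(model_name='finetuned_arctic')
hits = tuner.evaluate_embeddings_model(tuned, top_k_for_retrieval=5)
print(sum(r["is_hit"] for r in hits) / len(hits))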
myutils/pdfloader.py
ADDED
"""
pdfloader.py

This class loads a list of pdf documents passed in
and returns a list of parsed text for these docs.

The user can provide one of a few options to load pdfs:
pypdf or pymupdf
"""

# importing required classes
import os
from typing import List

from pypdf import PdfReader
import pymupdf


VALID_PDF_MODULES = ['pypdf', 'pymupdf']


class TextFromPdf:
    '''
    This class converts a list of pdf documents into a list of text documents.
    '''
    def __init__(self,
                 pdfmodule: str,
                 list_of_pdf_docs: List[str]):

        # validate pdfmodule
        if pdfmodule in VALID_PDF_MODULES:
            self.pdfmodule = pdfmodule
        else:
            print(f'ERROR: pdfmodule must be one of {VALID_PDF_MODULES}')
            raise Exception

        # validate input list
        if isinstance(list_of_pdf_docs, list) and len(list_of_pdf_docs) > 0:
            self.list_of_pdf_docs = list_of_pdf_docs
        else:
            print('ERROR: expecting a non-empty list of pdf names to be passed in')
            raise Exception
        return

    def process_single_pdf_with_pypdf(self, pdfdoc):
        # check if the file exists; if not, return None
        if not os.path.isfile(pdfdoc):
            print(f'Warning: pdf file {pdfdoc} does not exist...skipping to next pdf file')
            return None
        reader = PdfReader(pdfdoc)
        numpages = len(reader.pages)
        thistext = ''
        for pagecount in range(0, numpages):
            page = reader.pages[pagecount]
            pagetext = page.extract_text()
            thistext = thistext + '\n ' + pagetext  # adding a line break
        return thistext

    def process_single_pdf_with_pymupdf(self, pdfdoc):
        # check if the file exists; if not, return None
        if not os.path.isfile(pdfdoc):
            print(f'Warning: pdf file {pdfdoc} does not exist...skipping to next pdf file')
            return None

        doc = pymupdf.open(pdfdoc)  # open a document
        thistext = ''
        for page in doc:
            pagetext = page.get_text()  # get plain text (in UTF-8)
            thistext = thistext + '\n ' + pagetext  # adding a line break
        return thistext

    def process_all_pdfs(self):
        list_of_texts = []
        for pdfdoc in self.list_of_pdf_docs:
            # dispatch on the configured pdf module
            if self.pdfmodule == 'pypdf':
                pdftext = self.process_single_pdf_with_pypdf(pdfdoc)
            else:
                pdftext = self.process_single_pdf_with_pymupdf(pdfdoc)
            if pdftext is not None:
                list_of_texts.append([pdftext])
        return list_of_texts
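A minimal usage sketch for this loader; the file path is illustrative and must exist locally:

# hypothetical use of TextFromPdf with the pymupdf backend
from myutils.pdfloader import TextFromPdf

loader = TextFromPdf(pdfmodule='pymupdf',
                     list_of_pdf_docs=['./data/docs_for_rag/NIST.AI.600-1.pdf'])
texts = loader.process_all_pdfs()     # one single-element list of text per readable pdf
print(f'parsed {len(texts)} pdf(s); first doc has {len(texts[0][0])} characters')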
myutils/rag_pipeline_utils.py
ADDED
"""
rag_pipeline_utils.py

This python script implements various classes useful for a RAG pipeline.

Currently implemented:

Text splitting
    SimpleTextSplitter: uses RecursiveCharacterTextSplitter
    SemanticTextSplitter: uses SemanticChunker (different threshold types can be used)

VectorStore
    currently only sets up a Qdrant vector store in memory

AdvancedRetriever
    a simple retriever is a special case;
    advanced retriever: currently implements MultiQueryRetriever
"""

from operator import itemgetter
from typing import List

from langchain_core.runnables import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from datasets import Dataset

from ragas import evaluate


def load_all_pdfs(list_of_pdf_files: List[str]) -> List[Document]:
    alldocs = []
    for pdffile in list_of_pdf_files:
        thisdoc = PyMuPDFLoader(file_path=pdffile).load()
        print(f'loaded {pdffile} with {len(thisdoc)} pages')
        alldocs.extend(thisdoc)
    print(f'loaded all files: total number of pages: {len(alldocs)}')
    return alldocs


class SimpleTextSplitter:
    def __init__(self,
                 chunk_size,
                 chunk_overlap,
                 documents):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents = documents
        return

    def split_text(self):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap
        )
        all_splits = text_splitter.split_documents(self.documents)
        return all_splits


class SemanticTextSplitter:
    def __init__(self,
                 llm_embeddings=OpenAIEmbeddings(),
                 threshold_type="interquartile",
                 documents=None):
        self.llm_embeddings = llm_embeddings
        self.threshold_type = threshold_type
        self.documents = documents
        return

    def split_text(self):
        # use the configured threshold type rather than a hard-coded value
        text_splitter = SemanticChunker(
            embeddings=self.llm_embeddings,
            breakpoint_threshold_type=self.threshold_type
        )

        print(f'loaded {len(self.documents)} docs to be split')
        all_splits = text_splitter.split_documents(self.documents)
        print(f'returning docs split into {len(all_splits)} chunks')
        return all_splits


class VectorStore:
    def __init__(self,
                 location,
                 name,
                 documents,
                 size,
                 embedding=OpenAIEmbeddings()):
        self.location = location
        self.name = name
        self.size = size
        self.documents = documents
        self.embedding = embedding

        self.qdrant_client = QdrantClient(self.location)
        self.qdrant_client.create_collection(
            collection_name=self.name,
            vectors_config=VectorParams(size=self.size, distance=Distance.COSINE),
        )
        return

    def set_up_vectorstore(self):
        self.qdrant_vector_store = QdrantVectorStore(
            client=self.qdrant_client,
            collection_name=self.name,
            embedding=self.embedding
        )

        self.qdrant_vector_store.add_documents(self.documents)
        return self


class AdvancedRetriever:
    def __init__(self,
                 vectorstore):
        self.vectorstore = vectorstore
        return

    def set_up_simple_retriever(self):
        simple_retriever = self.vectorstore.as_retriever(
            search_type='similarity',
            search_kwargs={
                'k': 5
            }
        )
        return simple_retriever

    def set_up_multi_query_retriever(self, llm):
        retriever = self.set_up_simple_retriever()
        advanced_retriever = MultiQueryRetriever.from_llm(
            retriever=retriever, llm=llm
        )
        return advanced_retriever


def run_and_eval_rag_pipeline(location, collection_name, embed_dim, text_splits, embeddings,
                              prompt, qa_llm, metrics, test_df):
    """
    Helper function that runs and evaluates different rag pipelines
    based on the different text_splits presented to the pipeline.
    """
    # vector store
    vs = VectorStore(location=location,
                     name=collection_name,
                     documents=text_splits,
                     size=embed_dim,
                     embedding=embeddings)

    qdvs = vs.set_up_vectorstore().qdrant_vector_store

    # retriever
    retriever = AdvancedRetriever(vectorstore=qdvs).set_up_simple_retriever()

    # q&a chain using LCEL
    retrieval_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {"response": prompt | qa_llm, "context": itemgetter("context")}
    )

    # get questions and ground truth
    test_questions = test_df["question"].values.tolist()
    test_groundtruths = test_df["ground_truth"].values.tolist()

    # run RAG pipeline
    answers = []
    contexts = []

    for question in test_questions:
        response = retrieval_chain.invoke({"question": question})
        answers.append(response["response"].content)
        contexts.append([context.page_content for context in response["context"]])

    # save RAG pipeline results to a HF Dataset object
    response_dataset = Dataset.from_dict({
        "question": test_questions,
        "answer": answers,
        "contexts": contexts,
        "ground_truth": test_groundtruths
    })

    # run RAGAS evaluation using the given metrics
    results = evaluate(response_dataset, metrics)

    # save results to a df
    results_df = results.to_pandas()

    return results, results_df


def set_up_rag_pipeline(location, collection_name,
                        embeddings, embed_dim,
                        prompt, qa_llm,
                        text_splits):
    """
    Helper function that sets up a RAG pipeline.
    Inputs
        location: in-memory or persistent store
        collection_name: name of collection, string
        embeddings: object referring to the embeddings to be used
        embed_dim: embedding dimension
        prompt: prompt used in the RAG pipeline
        qa_llm: LLM used to generate the response
        text_splits: list containing text splits

    Returns a retrieval chain
    """
    # vector store
    vs = VectorStore(location=location,
                     name=collection_name,
                     documents=text_splits,
                     size=embed_dim,
                     embedding=embeddings)

    qdvs = vs.set_up_vectorstore().qdrant_vector_store

    # retriever
    retriever = AdvancedRetriever(vectorstore=qdvs).set_up_simple_retriever()

    # q&a chain using LCEL
    retrieval_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {"response": prompt | qa_llm, "context": itemgetter("context")}
    )

    return retrieval_chain


def test_rag_pipeline(retrieval_chain, list_of_questions):
    """
    Tests the RAG pipeline.
    Inputs
        retrieval_chain: retrieval chain
        list_of_questions: list of questions to use to test the RAG pipeline
    Output
        list of RAG-pipeline-generated responses to each question
    """
    all_answers = []
    for i, question in enumerate(list_of_questions):
        response = retrieval_chain.invoke({'question': question})
        answer = response["response"].content
        all_answers.append(answer)
    return all_answers


def get_vibe_check_on_list_of_questions(collection_name,
                                        embeddings, embed_dim,
                                        prompt, llm, text_splits,
                                        list_of_questions):
    """
    HELPER FUNCTION
    Sets up a retrieval chain for each scenario and prints out the results
    of the q_and_a for any list of questions.
    """

    # set up the baseline retriever
    retrieval_chain = \
        set_up_rag_pipeline(location=":memory:", collection_name=collection_name,
                            embeddings=embeddings, embed_dim=embed_dim,
                            prompt=prompt, qa_llm=llm,
                            text_splits=text_splits)

    # run the RAG pipeline and get responses
    answers = test_rag_pipeline(retrieval_chain, list_of_questions)

    # create question, answer tuples
    q_and_a = [(x, y) for x, y in zip(list_of_questions, answers)]

    # print out question/answer pairs to review the performance of the pipeline
    for i, item in enumerate(q_and_a):
        print('=================')
        print(f'=====question number: {i} =============')
        print(item[0])
        print(item[1])

    return retrieval_chain, q_and_a
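The helpers above compose into a complete pipeline in a few lines. A minimal sketch, assuming `OPENAI_API_KEY` is set and the local PDF path exists; the prompt here is a stripped-down stand-in for the fuller one in app.py:

# hypothetical end-to-end wiring of the helpers in this module
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from myutils.rag_pipeline_utils import load_all_pdfs, SimpleTextSplitter, set_up_rag_pipeline

docs = load_all_pdfs(['./data/docs_for_rag/NIST.AI.600-1.pdf'])
splits = SimpleTextSplitter(chunk_size=1000, chunk_overlap=300, documents=docs).split_text()

prompt = ChatPromptTemplate.from_template(
    "Answer from the context only.\n\nContext:\n{context}\n\nQuestion:\n{question}"
)
chain = set_up_rag_pipeline(location=":memory:", collection_name="demo",
                            embeddings=OpenAIEmbeddings(model="text-embedding-3-small"),
                            embed_dim=1536,
                            prompt=prompt,
                            qa_llm=ChatOpenAI(model="gpt-4o-mini", temperature=0),
                            text_splits=splits)
print(chain.invoke({"question": "What risks does the document identify?"})["response"].content)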
myutils/ragas_pipeline.py
ADDED
"""
ragas_pipeline.py

Implements the core pipeline to generate a test set for RAGAS.
"""

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas.testset.generator import TestsetGenerator
from ragas import evaluate

from datasets import Dataset

from myutils.rag_pipeline_utils import SimpleTextSplitter, SemanticTextSplitter, VectorStore, AdvancedRetriever


class RagasPipeline:
    def __init__(self, generator_llm_model, critic_llm_model, embedding_model,
                 number_of_qa_pairs,
                 chunk_size, chunk_overlap, documents,
                 distributions):
        self.generator_llm = ChatOpenAI(model=generator_llm_model)
        self.critic_llm = ChatOpenAI(model=critic_llm_model)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.number_of_qa_pairs = number_of_qa_pairs

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.documents = documents

        self.distributions = distributions

        self.generator = TestsetGenerator.from_langchain(
            self.generator_llm,
            self.critic_llm,
            self.embeddings
        )
        return

    def generate_testset(self):
        text_splitter = SimpleTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            documents=self.documents
        )
        ragas_text_splits = text_splitter.split_text()

        testset = self.generator.generate_with_langchain_docs(
            ragas_text_splits,
            self.number_of_qa_pairs,
            self.distributions
        )

        testset_df = testset.to_pandas()
        return testset_df

    def ragas_eval_of_rag_pipeline(self, retrieval_chain, ragas_questions, ragas_groundtruths, ragas_metrics):
        """
        Helper function that runs and evaluates different rag pipelines
        based on RAGAS test questions.
        """
        # run the RAG pipeline on the RAGAS synthetic questions
        answers = []
        contexts = []

        for question in ragas_questions:
            response = retrieval_chain.invoke({"question": question})
            answers.append(response["response"].content)
            contexts.append([context.page_content for context in response["context"]])

        # save RAG pipeline results to a HF Dataset object
        response_dataset = Dataset.from_dict({
            "question": ragas_questions,
            "answer": answers,
            "contexts": contexts,
            "ground_truth": ragas_groundtruths
        })

        # run RAGAS evaluation using the given metrics
        results = evaluate(response_dataset, ragas_metrics)

        # save results to a df
        results_df = results.to_pandas()

        return results, results_df
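A minimal sketch of generating a synthetic test set with this class, assuming ragas==0.1.16 as pinned in requirements.txt (the evolutions import path follows that release) and that `docs` holds loaded LangChain Documents:

# hypothetical use of RagasPipeline to build a synthetic eval set
from ragas.testset.evolutions import simple, reasoning, multi_context
from myutils.ragas_pipeline import RagasPipeline

pipeline = RagasPipeline(generator_llm_model="gpt-4o-mini",
                         critic_llm_model="gpt-4o-mini",
                         embedding_model="text-embedding-3-small",
                         number_of_qa_pairs=10,
                         chunk_size=1000, chunk_overlap=300,
                         documents=docs,   # assumed: loaded LC Documents
                         distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25})
testset_df = pipeline.generate_testset()   # DataFrame with question / ground_truth columns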
requirements.txt
ADDED
langchain
langchain-openai
langchain_core==0.2.38
langchain-community
langchainhub
langchain-qdrant
langchain_huggingface
langchain-text-splitters
langchain_experimental
ragas==0.1.16
openai
pymupdf
faiss-cpu
sentence_transformers
datasets
pyarrow==14.0.1
chainlit