JulsdL commited on
Commit
c4eb0c2
·
1 Parent(s): 881300d

Refactored the code for better maintainability in preparation for LangGraph multi-agent implementation

Browse files
CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  version 0.1.0 [2024-05-13]
2
 
3
  ## Added
 
1
+ version 0.1.1 [2024-05-13]
2
+
3
+ ## Modified
4
+
5
+ - Modularization: The code has been broken down into several modules, each with a specific responsibility. This makes the code easier to understand, test, and maintain. For example, the DocumentManager class in document_processing.py is responsible for managing documents and retrieving information from them. Similarly, the RetrievalManager class in retrieval.py is responsible for processing questions using a retrieval-augmented QA chain and returning the response.
6
+
7
+ - Separation of Concerns: The frontend and backend logic have been separated into different files (chainlit_frontend.py and document_processing.py, retrieval.py, etc.), which makes the codebase easier to navigate and maintain.
8
+
9
+ - Encapsulation: The code now makes use of classes and methods to encapsulate related functionality. For instance, the DocumentManager class encapsulates the functionality related to document management, and the RetrievalManager class encapsulates the functionality related to question processing and response retrieval.
10
+
11
  version 0.1.0 [2024-05-13]
12
 
13
  ## Added
README.md CHANGED
@@ -28,9 +28,9 @@ OPENAI_API_KEY=your-key-here
28
  4. Run the application using the following command:
29
 
30
  ```bash
31
- chainlit run app.py
32
  ```
33
 
34
  ## Usage
35
 
36
- Start a chat session and upload a Jupyter notebook file. The application will process the document and you can then ask questions related to the content of the notebook. It might take some time to answer some question, so please be patient.
 
28
  4. Run the application using the following command:
29
 
30
  ```bash
31
+ chainlit run aims_tutor/app.py
32
  ```
33
 
34
  ## Usage
35
 
36
+ Start a chat session and upload a Jupyter notebook file. The application will process the document and you can then ask questions related to the content of the notebook. It might take some time to answer some questions (should be less than 1 min), so please be patient.
aims_tutor/__init__.py ADDED
File without changes
aims_tutor/app.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from dotenv import load_dotenv

# Load environment variables (e.g. OPENAI_API_KEY) before the frontend module
# and its langchain imports run — document_processing reads the key at import time.
load_dotenv()

# Importing the frontend module registers the chainlit event handlers
# (@cl.on_chat_start / @cl.on_message). Chainlit invokes them itself when run
# via `chainlit run aims_tutor/app.py`, so they must NOT be called directly:
# the previous __main__ block called a non-existent `handle_user_query` (the
# handler is named `main`) and created a never-awaited coroutine from the
# async `start_chat`.
import aims_tutor.chainlit_frontend  # noqa: F401
aims_tutor/chainlit_frontend.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import chainlit as cl
from dotenv import load_dotenv

# Package-absolute imports so the module resolves when imported as
# `aims_tutor.chainlit_frontend` (per app.py / the README run command),
# consistent with `from aims_tutor.utils import tiktoken_len` in
# document_processing.py. Bare `from document_processing import ...` only
# works if aims_tutor/ itself happens to be on sys.path.
from aims_tutor.document_processing import DocumentManager
from aims_tutor.retrieval import RetrievalManager

# Load environment variables (OPENAI_API_KEY etc.)
load_dotenv()
8
+
9
@cl.on_chat_start
async def start_chat():
    """Chat-session start handler.

    Greets the user, waits for a notebook upload, then builds the document
    and retrieval managers and stores them in the chainlit user session for
    the @cl.on_message handler to use.
    """
    # Default LLM settings kept in the session for later use.
    chat_settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", chat_settings)

    await cl.Message(
        content="Welcome to the AIMS-Tutor! Please upload a Jupyter notebook (.ipynb and max. 5mb) to start."
    ).send()

    # Keep prompting until the user actually uploads a file.
    uploaded = None
    while uploaded is None:
        uploaded = await cl.AskFileMessage(
            content="Please upload a Jupyter notebook (.ipynb, max. 5mb):",
            accept={"application/x-ipynb+json": [".ipynb"]},
            max_size_mb=5,
        ).send()

    notebook_file = uploaded[0]  # first (and only) uploaded file
    if notebook_file:
        # Build the retrieval pipeline over the uploaded notebook and stash
        # everything in the session for the message handler.
        manager = DocumentManager(notebook_file.path)
        manager.load_document()
        manager.initialize_retriever()
        cl.user_session.set("docs", manager.get_documents())
        cl.user_session.set("retrieval_manager", RetrievalManager(manager.get_retriever()))
38
+
39
@cl.on_message
async def main(message: cl.Message):
    """Message handler: answer a user question via the retrieval-augmented QA chain."""
    # The retrieval pipeline is created at chat start; without it we cannot answer.
    manager = cl.user_session.get("retrieval_manager")
    if not manager:
        await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
        return

    # Run the question through the RAG chain and send the answer back.
    answer = manager.notebook_QA(message.content)
    await cl.Message(content=answer).send()
aims_tutor/document_processing.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from langchain_community.document_loaders import NotebookLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
from langchain.retrievers import MultiQueryRetriever
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from aims_tutor.utils import tiktoken_len

# Load environment variables
load_dotenv()

# Configuration for OpenAI
# Reading the key here fails fast at import time if it is missing.
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
# Module-level LLM shared by the multi-query retriever (see DocumentManager.initialize_retriever).
openai_chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0.1)
17
+
18
class DocumentManager:
    """
    A class for managing documents and retrieving information from them.

    Attributes:
        notebook_path (str): The path to the notebook file.
        docs (list): A list of loaded documents (None until load_document()).
        retriever (object): The retriever object used for document retrieval
            (None until initialize_retriever()).

    Methods:
        load_document(): Loads the documents from the notebook file.
        initialize_retriever(): Initializes the retriever object for document retrieval.
        get_retriever(): Returns the retriever object.
        get_documents(): Returns the loaded documents.
    """
    def __init__(self, notebook_path):
        self.notebook_path = notebook_path
        self.docs = None       # populated by load_document()
        self.retriever = None  # populated by initialize_retriever()

    def load_document(self):
        """
        Load the documents from the notebook file.

        Initializes a `NotebookLoader` for `self.notebook_path` and stores the
        loaded documents in `self.docs`.

        Returns:
            None
        """
        loader = NotebookLoader(
            self.notebook_path,
            include_outputs=False,  # cell outputs are dropped to keep chunks focused on source cells
            max_output_length=20,
            remove_newline=True,
            traceback=False
        )
        self.docs = loader.load()

    def initialize_retriever(self):
        """
        Build the retrieval pipeline over the loaded documents.

        Splits `self.docs` into token-length-bounded chunks, embeds them into
        an in-memory Qdrant vector store, and wraps the store's retriever in a
        `MultiQueryRetriever`, which is stored in `self.retriever`.

        Requires `load_document()` to have been called first (`self.docs` set).

        Returns:
            None
        """
        # Chunk by token count (tiktoken_len) rather than characters; the
        # 50-token overlap preserves context across chunk boundaries.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50, length_function=tiktoken_len)

        split_chunks = text_splitter.split_documents(self.docs)

        embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

        # In-memory store: the index only lives for this session.
        qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook")

        qdrant_retriever = qdrant_vectorstore.as_retriever()  # Set the Qdrant vector store as a retriever

        multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model, include_original=True)  # Create a multi-query retriever on top of the Qdrant retriever

        self.retriever = multiquery_retriever

    def get_retriever(self):
        # Accessor for the retriever built by initialize_retriever() (None before that).
        return self.retriever

    def get_documents(self):
        # Accessor for the documents loaded by load_document() (None before that).
        return self.docs
aims_tutor/prompt_templates.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+
3
class PromptTemplates:
    """
    A collection of prompt templates used for generating chat prompts.

    Attributes:
        rag_QA_prompt (ChatPromptTemplate): A prompt template for generating RAG QA prompts.

    Methods:
        __init__(): Initializes all prompt templates as instance variables.
        get_rag_qa_prompt(): Returns the RAG QA prompt.

    Example usage:
        prompt_templates = PromptTemplates()
        rag_qa_prompt = prompt_templates.get_rag_qa_prompt()
    """
    def __init__(self):
        # Initializes all prompt templates as instance variables.
        # The template instructs the model to answer only from the retrieved
        # {context} and to refuse when the {question} is unrelated to it.
        self.rag_QA_prompt = ChatPromptTemplate.from_template("""
        CONTEXT:
        {context}

        QUERY:
        {question}

        Answer the query in a pretty format if the context is related to it; otherwise, answer: 'Sorry, I can't answer. Please ask another question.'
        """)

    def get_rag_qa_prompt(self):
        # Returns the RAG QA prompt
        return self.rag_QA_prompt
aims_tutor/retrieval.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from operator import itemgetter

from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Package-absolute import, consistent with the `aims_tutor.*` import style used
# elsewhere (e.g. `from aims_tutor.utils import tiktoken_len` in
# document_processing.py); a bare `from prompt_templates import ...` fails when
# this module is imported as `aims_tutor.retrieval`.
from aims_tutor.prompt_templates import PromptTemplates
5
+
6
+
7
class RetrievalManager:
    """
    RetrievalManager class.

    This class represents a retrieval manager that processes questions using a retrieval-augmented QA chain and returns the response.

    Attributes:
        retriever (object): The retriever object used for retrieval.
        chat_model (object): The ChatOpenAI object representing the OpenAI Chat model.
        prompts (PromptTemplates): Source of the RAG QA prompt template.

    Methods:
        notebook_QA(question):
            Processes a question using the retrieval-augmented QA chain and returns the response.
    """
    def __init__(self, retriever):
        self.retriever = retriever
        self.chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0.1)
        self.prompts = PromptTemplates()

    def notebook_QA(self, question):
        """
        Processes a question using the retrieval-augmented QA chain and returns the response.

        Parameters:
            question (str): The question to be processed.

        Returns:
            str: The response generated by the retrieval-augmented QA chain.
        """
        # LCEL chain: retrieve context documents for the question, then feed
        # {context, question} into the RAG prompt and the chat model.
        # NOTE(review): RunnablePassthrough.assign(context=itemgetter("context"))
        # re-assigns "context" to its existing value and looks like a no-op —
        # confirm before simplifying.
        retrieval_augmented_qa_chain = (
            {"context": itemgetter("question") | self.retriever, "question": itemgetter("question")}
            | RunnablePassthrough.assign(context=itemgetter("context"))
            | {"response": self.prompts.get_rag_qa_prompt() | self.chat_model, "context": itemgetter("context")}
        )

        response = retrieval_augmented_qa_chain.invoke({"question": question})

        # The "response" entry is an AIMessage; return only its text content.
        return response["response"].content
aims_tutor/utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
import tiktoken

# Build the encoder once at import time: tiktoken.encoding_for_model() loads the
# BPE ranks, which is far too costly to repeat on every call — tiktoken_len runs
# once per chunk while the text splitter measures lengths.
_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo")

def tiktoken_len(text):
    """Return the number of gpt-3.5-turbo BPE tokens in *text*."""
    return len(_ENCODING.encode(text))
app.py → main.py RENAMED
@@ -20,7 +20,7 @@ load_dotenv()
20
 
21
  # Configuration for OpenAI
22
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
23
- openai_chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0)
24
 
25
  # Define the RAG prompt
26
  RAG_PROMPT = """
@@ -68,7 +68,7 @@ async def start_chat():
68
 
69
  loader = NotebookLoader(
70
  notebook_path,
71
- include_outputs=True,
72
  max_output_length=20,
73
  remove_newline=True,
74
  traceback=False
@@ -82,33 +82,18 @@ async def start_chat():
82
  embedding_model = OpenAIEmbeddings(model="text-embedding-3-small") # Initialize the embedding model
83
  qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook") # Create a Qdrant vector store
84
  qdrant_retriever = qdrant_vectorstore.as_retriever() # Set the Qdrant vector store as a retriever
85
- multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model) # Create a multi-query retriever on top of the Qdrant retriever
86
 
87
  # Store the multiquery_retriever in the user session
88
  cl.user_session.set("multiquery_retriever", multiquery_retriever)
89
 
90
- @cl.on_message
91
- async def main(message: cl.Message):
92
- # Retrieve the multi-query retriever from session
93
- multiquery_retriever = cl.user_session.get("multiquery_retriever")
94
-
95
- if not multiquery_retriever:
96
- await message.reply("No document processing chain found. Please upload a Jupyter notebook first.")
97
- return
98
-
99
- question = message.content
100
- response = handle_query(question, multiquery_retriever) # Process the question
101
-
102
- msg = cl.Message(content=response)
103
- await msg.send()
104
-
105
 
106
  @cl.on_message
107
  async def main(message: cl.Message):
108
  # Retrieve the multi-query retriever from session
109
  multiquery_retriever = cl.user_session.get("multiquery_retriever")
110
  if not multiquery_retriever:
111
- await message.reply("No document processing setup found. Please upload a Jupyter notebook first.")
112
  return
113
 
114
  question = message.content
 
20
 
21
  # Configuration for OpenAI
22
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
23
+ openai_chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0.1)
24
 
25
  # Define the RAG prompt
26
  RAG_PROMPT = """
 
68
 
69
  loader = NotebookLoader(
70
  notebook_path,
71
+ include_outputs=False,
72
  max_output_length=20,
73
  remove_newline=True,
74
  traceback=False
 
82
  embedding_model = OpenAIEmbeddings(model="text-embedding-3-small") # Initialize the embedding model
83
  qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook") # Create a Qdrant vector store
84
  qdrant_retriever = qdrant_vectorstore.as_retriever() # Set the Qdrant vector store as a retriever
85
+ multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model, include_original=True) # Create a multi-query retriever on top of the Qdrant retriever
86
 
87
  # Store the multiquery_retriever in the user session
88
  cl.user_session.set("multiquery_retriever", multiquery_retriever)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  @cl.on_message
92
  async def main(message: cl.Message):
93
  # Retrieve the multi-query retriever from session
94
  multiquery_retriever = cl.user_session.get("multiquery_retriever")
95
  if not multiquery_retriever:
96
+ await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
97
  return
98
 
99
  question = message.content