JulsdL commited on
Commit
e3c5c37
·
unverified ·
2 Parent(s): 85e07b4 1d72eb0

Merge pull request #2 from JulsdL/quiz_functionnality

Browse files

Implementation of Quiz Functionality with LangGraph Integration

CHANGELOG.md CHANGED
@@ -1,3 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  version 0.1.1 [2024-05-13]
2
 
3
  ## Modified
 
1
+ version 0.2.0 [2024-05-14]
2
+
3
+ ## Added
4
+
5
+ - Introduced a comprehensive quiz functionality with LangGraph integration, enabling dynamic quiz generation and question answering based on Jupyter notebook content.
6
+ - Added new Python dependencies (`langgraph==0.0.48`) to support the quiz functionality and improved interaction flow.
7
+ - Implemented a new `graph.py` module to define the quiz and QA agents, along with the supervisor logic for managing conversation flow between agents.
8
+ - Enhanced the `chainlit_frontend.py` to integrate the LangGraph chain, facilitating real-time interaction with the quiz and QA functionality.
9
+ - Updated the `document_processing.py` and `retrieval.py` modules to support the new quiz functionality, including adjustments to the OpenAI model configuration and retrieval logic.
10
+
11
+ ## Modified
12
+
13
+ - Updated the OpenAI model used in `document_processing.py` from "gpt-4-turbo" to "gpt-4o" to improve the quality of document processing and retrieval.
14
+ - Refined the retrieval logic in `retrieval.py` to include a new method for initializing the RAG QA chain, enhancing the system's ability to provide accurate and contextually relevant answers.
15
+
16
  version 0.1.1 [2024-05-13]
17
 
18
  ## Modified
aims_tutor/chainlit_frontend.py CHANGED
@@ -2,6 +2,8 @@ import chainlit as cl
2
  from dotenv import load_dotenv
3
  from document_processing import DocumentManager
4
  from retrieval import RetrievalManager
 
 
5
 
6
  # Load environment variables
7
  load_dotenv()
@@ -28,6 +30,7 @@ async def start_chat():
28
  ).send()
29
 
30
  file = files[0] # Get the first file
 
31
  if file:
32
  notebook_path = file.path
33
  doc_manager = DocumentManager(notebook_path)
@@ -36,16 +39,51 @@ async def start_chat():
36
  cl.user_session.set("docs", doc_manager.get_documents())
37
  cl.user_session.set("retrieval_manager", RetrievalManager(doc_manager.get_retriever()))
38
 
 
 
 
 
 
 
39
  @cl.on_message
40
  async def main(message: cl.Message):
41
- # Retrieve the multi-query retriever from session
42
- retrieval_manager = cl.user_session.get("retrieval_manager")
43
- if not retrieval_manager:
 
44
  await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
45
  return
46
 
47
- question = message.content
48
- response = retrieval_manager.notebook_QA(question) # Process the question
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- msg = cl.Message(content=response)
51
- await msg.send()
 
2
  from dotenv import load_dotenv
3
  from document_processing import DocumentManager
4
  from retrieval import RetrievalManager
5
+ from langchain_core.messages import AIMessage, HumanMessage
6
+ from graph import create_aims_chain, AIMSState
7
 
8
  # Load environment variables
9
  load_dotenv()
 
30
  ).send()
31
 
32
  file = files[0] # Get the first file
33
+
34
  if file:
35
  notebook_path = file.path
36
  doc_manager = DocumentManager(notebook_path)
 
39
  cl.user_session.set("docs", doc_manager.get_documents())
40
  cl.user_session.set("retrieval_manager", RetrievalManager(doc_manager.get_retriever()))
41
 
42
+ # Initialize LangGraph chain with the retrieval chain
43
+ retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
44
+ cl.user_session.set("retrieval_chain", retrieval_chain)
45
+ aims_chain = create_aims_chain(retrieval_chain)
46
+ cl.user_session.set("aims_chain", aims_chain)
47
+
48
  @cl.on_message
49
  async def main(message: cl.Message):
50
+ # Retrieve the LangGraph chain from the session
51
+ aims_chain = cl.user_session.get("aims_chain")
52
+
53
+ if not aims_chain:
54
  await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
55
  return
56
 
57
+ # Create the initial state with the user message
58
+ user_message = message.content
59
+ state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[], quiz_created=False, question_answered=False)
60
+
61
+
62
+ print(f"Initial state: {state}")
63
+
64
+ # Process the message through the LangGraph chain
65
+ for s in aims_chain.stream(state, {"recursion_limit": 10}):
66
+ print(f"State after processing: {s}")
67
+
68
+ # Extract messages from the state
69
+ if "__end__" not in s:
70
+ agent_state = next(iter(s.values()))
71
+ if "messages" in agent_state:
72
+ response = agent_state["messages"][-1].content
73
+ print(f"Response: {response}")
74
+ await cl.Message(content=response).send()
75
+ else:
76
+ print("Error: No messages found in agent state.")
77
+ else:
78
+ # Check if the quiz was created and send it to the frontend
79
+ if state["quiz_created"]:
80
+ quiz_message = state["messages"][-1].content
81
+ await cl.Message(content=quiz_message).send()
82
+ # Check if a question was answered and send the response to the frontend
83
+ if state["question_answered"]:
84
+ qa_message = state["messages"][-1].content
85
+ await cl.Message(content=qa_message).send()
86
+
87
+ print("Reached end state.")
88
 
89
+ break
 
aims_tutor/document_processing.py CHANGED
@@ -13,7 +13,7 @@ load_dotenv()
13
 
14
  # Configuration for OpenAI
15
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
16
- openai_chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0.1)
17
 
18
  class DocumentManager:
19
  """
@@ -82,7 +82,7 @@ class DocumentManager:
82
 
83
  qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook")
84
 
85
- qdrant_retriever = qdrant_vectorstore.as_retriever() # Set the Qdrant vector store as a retriever
86
 
87
  multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model, include_original=True) # Create a multi-query retriever on top of the Qdrant retriever
88
 
 
13
 
14
  # Configuration for OpenAI
15
  OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
16
+ openai_chat_model = ChatOpenAI(model="gpt-4o", temperature=0.1)
17
 
18
  class DocumentManager:
19
  """
 
82
 
83
  qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook")
84
 
85
+ qdrant_retriever = qdrant_vectorstore.as_retriever()
86
 
87
  multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model, include_original=True) # Create a multi-query retriever on top of the Qdrant retriever
88
 
aims_tutor/graph.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List, TypedDict
2
+ from dotenv import load_dotenv
3
+ from langchain_core.tools import tool
4
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
+ from langchain_core.messages import AIMessage, BaseMessage
6
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.graph import END, StateGraph
10
+ import functools
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
+
15
+ # Instantiate the language model
16
+ llm = ChatOpenAI(model="gpt-4o")
17
+
18
+ class RetrievalChainWrapper:
19
+ def __init__(self, retrieval_chain):
20
+ self.retrieval_chain = retrieval_chain
21
+
22
+ def retrieve_information(
23
+ self,
24
+ query: Annotated[str, "query to ask the RAG tool"]
25
+ ):
26
+ """Use this tool to retrieve information about the provided notebook."""
27
+ response = self.retrieval_chain.invoke({"question": query})
28
+ return response["response"].content
29
+
30
+ # Create an instance of the wrapper
31
+ def get_retrieve_information_tool(retrieval_chain):
32
+ wrapper_instance = RetrievalChainWrapper(retrieval_chain)
33
+ return tool(wrapper_instance.retrieve_information)
34
+
35
+ # Function to create agents
36
+ def create_agent(
37
+ llm: ChatOpenAI,
38
+ tools: list,
39
+ system_prompt: str,
40
+ ) -> AgentExecutor:
41
+ """Create a function-calling agent and add it to the graph."""
42
+ system_prompt += "\nWork autonomously according to your specialty, using the tools available to you."
43
+ " Do not ask for clarification."
44
+ " Your other team members (and other teams) will collaborate with you with their own specialties."
45
+ " You are chosen for a reason! You are one of the following team members: {team_members}."
46
+ prompt = ChatPromptTemplate.from_messages(
47
+ [
48
+ (
49
+ "system",
50
+ system_prompt,
51
+ ),
52
+ MessagesPlaceholder(variable_name="messages"),
53
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
54
+ ]
55
+ )
56
+ agent = create_openai_functions_agent(llm, tools, prompt)
57
+ executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
58
+ return executor
59
+
60
+ # Function to create agent nodes
61
+ def agent_node(state, agent, name):
62
+ result = agent.invoke(state)
63
+ if 'messages' not in result: # Check if messages are present in the agent state
64
+ raise ValueError(f"No messages found in agent state: {result}")
65
+ new_state = {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
66
+ if "next" in result:
67
+ new_state["next"] = result["next"]
68
+ if name == "QuizAgent" and "quiz_created" in state and not state["quiz_created"]:
69
+ new_state["quiz_created"] = True
70
+ new_state["next"] = "FINISH" # Finish the conversation after the quiz is created and wait for a new user input
71
+ if name == "QAAgent":
72
+ new_state["question_answered"] = True
73
+ new_state["next"] = "question_answered"
74
+ return new_state
75
+
76
+
77
+ # Function to create the supervisor
78
+ def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
79
+ """An LLM-based router."""
80
+ options = ["WAIT", "FINISH"] + members
81
+ function_def = {
82
+ "name": "route",
83
+ "description": "Select the next role.",
84
+ "parameters": {
85
+ "title": "routeSchema",
86
+ "type": "object",
87
+ "properties": {
88
+ "next": {
89
+ "title": "Next",
90
+ "anyOf": [
91
+ {"enum": options},
92
+ ],
93
+ },
94
+ },
95
+ "required": ["next"],
96
+ },
97
+ }
98
+ prompt = ChatPromptTemplate.from_messages(
99
+ [
100
+ ("system", system_prompt),
101
+ MessagesPlaceholder(variable_name="messages"),
102
+ (
103
+ "system",
104
+ "Given the conversation above, who should act next?"
105
+ " Or should we WAIT for user input? Select one of: {options}",
106
+ ),
107
+ ]
108
+ ).partial(options=str(options), team_members=", ".join(members))
109
+ return (
110
+ prompt
111
+ | llm.bind_functions(functions=[function_def], function_call="route")
112
+ | JsonOutputFunctionsParser()
113
+ )
114
+
115
+ # Define the state for the system
116
+ class AIMSState(TypedDict):
117
+ messages: List[BaseMessage]
118
+ next: str
119
+ quiz: List[dict]
120
+ quiz_created: bool
121
+ question_answered: bool
122
+
123
+
124
+ # Create the LangGraph chain
125
+ def create_aims_chain(retrieval_chain):
126
+
127
+ retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)
128
+
129
+ # Create QA Agent
130
+ qa_agent = create_agent(
131
+ llm,
132
+ [retrieve_information_tool],
133
+ "You are a QA assistant who answers questions about the provided notebook content.",
134
+ )
135
+
136
+ qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")
137
+
138
+ # Create Quiz Agent
139
+ quiz_agent = create_agent(
140
+ llm,
141
+ [retrieve_information_tool],
142
+ "You are a quiz creator that generates quizzes based on the provided notebook content."
143
+
144
+ """First, You MUST Use the retrieval_inforation_tool to gather context from the notebook to gather relevant and accurate information.
145
+
146
+ Next, create a 5-question quiz based on the information you have gathered. Include the answers at the end of the quiz.
147
+
148
+ Present the quiz to the user in a clear and concise manner."""
149
+ )
150
+
151
+ quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
152
+
153
+ # Create Supervisor Agent
154
+ supervisor_agent = create_team_supervisor(
155
+ llm,
156
+ "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
157
+ ["QAAgent", "QuizAgent"],
158
+ )
159
+
160
+ # Build the LangGraph
161
+ aims_graph = StateGraph(AIMSState)
162
+ aims_graph.add_node("QAAgent", qa_node)
163
+ aims_graph.add_node("QuizAgent", quiz_node)
164
+ aims_graph.add_node("supervisor", supervisor_agent)
165
+
166
+ aims_graph.add_edge("QAAgent", "supervisor")
167
+ aims_graph.add_edge("QuizAgent", "supervisor")
168
+ aims_graph.add_conditional_edges(
169
+ "supervisor",
170
+ lambda x: "FINISH" if x.get("quiz_created") else ("FINISH" if x.get("question_answered") else x["next"]),
171
+ {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END, "question_answered": END},
172
+ )
173
+
174
+ aims_graph.set_entry_point("supervisor")
175
+ return aims_graph.compile()
aims_tutor/retrieval.py CHANGED
@@ -42,3 +42,10 @@ class RetrievalManager:
42
  response = retrieval_augmented_qa_chain.invoke({"question": question})
43
 
44
  return response["response"].content
 
 
 
 
 
 
 
 
42
  response = retrieval_augmented_qa_chain.invoke({"question": question})
43
 
44
  return response["response"].content
45
+
46
+ def get_RAG_QA_chain(self):
47
+ return (
48
+ {"context": itemgetter("question") | self.retriever, "question": itemgetter("question")}
49
+ | RunnablePassthrough.assign(context=itemgetter("context"))
50
+ | {"response": self.prompts.get_rag_qa_prompt() | self.chat_model, "context": itemgetter("context")}
51
+ )
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  langchain==0.1.20
 
2
  crewai==0.30.0
3
  qdrant-client==1.9.1
4
  python-dotenv==1.0.1
 
1
  langchain==0.1.20
2
+ langgraph==0.0.48
3
  crewai==0.30.0
4
  qdrant-client==1.9.1
5
  python-dotenv==1.0.1