Merge pull request #2 from JulsdL/quiz_functionnality
Implementation of Quiz Functionality with LangGraph Integration
- CHANGELOG.md +15 -0
- aims_tutor/chainlit_frontend.py +45 -7
- aims_tutor/document_processing.py +2 -2
- aims_tutor/graph.py +175 -0
- aims_tutor/retrieval.py +7 -0
- requirements.txt +1 -0
CHANGELOG.md
CHANGED

@@ -1,3 +1,18 @@
+version 0.2.0 [2024-05-14]
+
+## Added
+
+- Introduced a comprehensive quiz functionality with LangGraph integration, enabling dynamic quiz generation and question answering based on Jupyter notebook content.
+- Added a new Python dependency (`langgraph==0.0.48`) to support the quiz functionality and improved interaction flow.
+- Implemented a new `graph.py` module to define the quiz and QA agents, along with the supervisor logic for managing conversation flow between agents.
+- Enhanced `chainlit_frontend.py` to integrate the LangGraph chain, facilitating real-time interaction with the quiz and QA functionality.
+- Updated the `document_processing.py` and `retrieval.py` modules to support the new quiz functionality, including adjustments to the OpenAI model configuration and retrieval logic.
+
+## Modified
+
+- Updated the OpenAI model used in `document_processing.py` from "gpt-4-turbo" to "gpt-4o" to improve the quality of document processing and retrieval.
+- Refined the retrieval logic in `retrieval.py` to add a method for initializing the RAG QA chain, enhancing the system's ability to provide accurate and contextually relevant answers.
+
 version 0.1.1 [2024-05-13]
 
 ## Modified
aims_tutor/chainlit_frontend.py
CHANGED

@@ -2,6 +2,8 @@ import chainlit as cl
 from dotenv import load_dotenv
 from document_processing import DocumentManager
 from retrieval import RetrievalManager
+from langchain_core.messages import AIMessage, HumanMessage
+from graph import create_aims_chain, AIMSState
 
 # Load environment variables
 load_dotenv()

@@ -28,6 +30,7 @@ async def start_chat():
     ).send()
 
     file = files[0]  # Get the first file
+
     if file:
         notebook_path = file.path
         doc_manager = DocumentManager(notebook_path)

@@ -36,16 +39,51 @@ async def start_chat():
         cl.user_session.set("docs", doc_manager.get_documents())
         cl.user_session.set("retrieval_manager", RetrievalManager(doc_manager.get_retriever()))
 
+        # Initialize the LangGraph chain with the retrieval chain
+        retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
+        cl.user_session.set("retrieval_chain", retrieval_chain)
+        aims_chain = create_aims_chain(retrieval_chain)
+        cl.user_session.set("aims_chain", aims_chain)
+
 @cl.on_message
 async def main(message: cl.Message):
-    # Retrieve the
+    # Retrieve the LangGraph chain from the session
+    aims_chain = cl.user_session.get("aims_chain")
+
+    if not aims_chain:
         await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
         return
 
+    # Create the initial state with the user message
+    user_message = message.content
+    state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[], quiz_created=False, question_answered=False)
+
+    print(f"Initial state: {state}")
+
+    # Process the message through the LangGraph chain
+    for s in aims_chain.stream(state, {"recursion_limit": 10}):
+        print(f"State after processing: {s}")
+
+        # Extract messages from the state
+        if "__end__" not in s:
+            agent_state = next(iter(s.values()))
+            if "messages" in agent_state:
+                response = agent_state["messages"][-1].content
+                print(f"Response: {response}")
+                await cl.Message(content=response).send()
+            else:
+                print("Error: No messages found in agent state.")
+        else:
+            # Check if the quiz was created and send it to the frontend
+            if state["quiz_created"]:
+                quiz_message = state["messages"][-1].content
+                await cl.Message(content=quiz_message).send()
+            # Check if a question was answered and send the response to the frontend
+            if state["question_answered"]:
+                qa_message = state["messages"][-1].content
+                await cl.Message(content=qa_message).send()
+
+            print("Reached end state.")
+            break
-    await msg.send()
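The handler above unwraps each streamed step with `next(iter(s.values()))` and watches for an `"__end__"` key. For reference, here is a minimal, self-contained sketch of that stream shape under the `langgraph==0.0.48` pin from `requirements.txt`; the toy state and node names are invented for illustration and are not part of this repo.

```python
from typing import List, TypedDict

from langgraph.graph import END, StateGraph

class ToyState(TypedDict):
    messages: List[str]

def toy_node(state: ToyState) -> ToyState:
    # Each node returns a state update; langgraph merges it into the graph state.
    return {"messages": state["messages"] + ["hello from toy_node"]}

graph = StateGraph(ToyState)
graph.add_node("toy_node", toy_node)
graph.set_entry_point("toy_node")
graph.add_edge("toy_node", END)
app = graph.compile()

for step in app.stream({"messages": []}):
    # Yields {"toy_node": {...}} for the node step, then an {"__end__": ...} entry,
    # which is why the handler above checks '"__end__" not in s' before unwrapping.
    if "__end__" not in step:
        node_update = next(iter(step.values()))
        print(node_update["messages"][-1])
```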
aims_tutor/document_processing.py
CHANGED

@@ -13,7 +13,7 @@ load_dotenv()
 
 # Configuration for OpenAI
 OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
-openai_chat_model = ChatOpenAI(model="gpt-4-turbo")
+openai_chat_model = ChatOpenAI(model="gpt-4o", temperature=0.1)
 
 class DocumentManager:
     """

@@ -82,7 +82,7 @@ class DocumentManager:
 
         qdrant_vectorstore = Qdrant.from_documents(split_chunks, embedding_model, location=":memory:", collection_name="Notebook")
 
-        qdrant_retriever = qdrant_vectorstore.as_retriever()
+        qdrant_retriever = qdrant_vectorstore.as_retriever()
 
         multiquery_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=openai_chat_model, include_original=True)  # Create a multi-query retriever on top of the Qdrant retriever
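For context on the retrieval stack around the edited lines, the sketch below reconstructs it in isolation. It assumes `OPENAI_API_KEY` is set; the embedding model name and sample texts are placeholders, not values from this repo.

```python
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")  # placeholder model
openai_chat_model = ChatOpenAI(model="gpt-4o", temperature=0.1)

qdrant_vectorstore = Qdrant.from_texts(
    ["Gradient descent minimizes a loss function.", "Qdrant can store embeddings in memory."],
    embedding_model,
    location=":memory:",
    collection_name="Notebook",
)

# MultiQueryRetriever has the LLM rephrase the incoming query several ways, runs
# every variant against the base retriever, and returns the deduplicated union.
multiquery_retriever = MultiQueryRetriever.from_llm(
    retriever=qdrant_vectorstore.as_retriever(),
    llm=openai_chat_model,
    include_original=True,  # also run the original, unrephrased query
)
docs = multiquery_retriever.get_relevant_documents("How does gradient descent work?")
```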
aims_tutor/graph.py
ADDED

@@ -0,0 +1,175 @@
+from typing import Annotated, List, TypedDict
+from dotenv import load_dotenv
+from langchain_core.tools import tool
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.messages import AIMessage, BaseMessage
+from langchain.agents import AgentExecutor, create_openai_functions_agent
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain_openai import ChatOpenAI
+from langgraph.graph import END, StateGraph
+import functools
+
+# Load environment variables
+load_dotenv()
+
+# Instantiate the language model
+llm = ChatOpenAI(model="gpt-4o")
+
+class RetrievalChainWrapper:
+    def __init__(self, retrieval_chain):
+        self.retrieval_chain = retrieval_chain
+
+    def retrieve_information(
+        self,
+        query: Annotated[str, "query to ask the RAG tool"]
+    ):
+        """Use this tool to retrieve information about the provided notebook."""
+        response = self.retrieval_chain.invoke({"question": query})
+        return response["response"].content
+
+# Create an instance of the wrapper
+def get_retrieve_information_tool(retrieval_chain):
+    wrapper_instance = RetrievalChainWrapper(retrieval_chain)
+    return tool(wrapper_instance.retrieve_information)
+
+# Function to create agents
+def create_agent(
+    llm: ChatOpenAI,
+    tools: list,
+    system_prompt: str,
+) -> AgentExecutor:
+    """Create a function-calling agent and add it to the graph."""
+    system_prompt += "\nWork autonomously according to your specialty, using the tools available to you."
+    " Do not ask for clarification."
+    " Your other team members (and other teams) will collaborate with you with their own specialties."
+    " You are chosen for a reason! You are one of the following team members: {team_members}."
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                system_prompt,
+            ),
+            MessagesPlaceholder(variable_name="messages"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+    agent = create_openai_functions_agent(llm, tools, prompt)
+    executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
+    return executor
+
+# Function to create agent nodes
+def agent_node(state, agent, name):
+    result = agent.invoke(state)
+    if 'messages' not in result:  # Check if messages are present in the agent state
+        raise ValueError(f"No messages found in agent state: {result}")
+    new_state = {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
+    if "next" in result:
+        new_state["next"] = result["next"]
+    if name == "QuizAgent" and "quiz_created" in state and not state["quiz_created"]:
+        new_state["quiz_created"] = True
+        new_state["next"] = "FINISH"  # Finish the conversation after the quiz is created and wait for new user input
+    if name == "QAAgent":
+        new_state["question_answered"] = True
+        new_state["next"] = "question_answered"
+    return new_state
+
+
+# Function to create the supervisor
+def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
+    """An LLM-based router."""
+    options = ["WAIT", "FINISH"] + members
+    function_def = {
+        "name": "route",
+        "description": "Select the next role.",
+        "parameters": {
+            "title": "routeSchema",
+            "type": "object",
+            "properties": {
+                "next": {
+                    "title": "Next",
+                    "anyOf": [
+                        {"enum": options},
+                    ],
+                },
+            },
+            "required": ["next"],
+        },
+    }
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            MessagesPlaceholder(variable_name="messages"),
+            (
+                "system",
+                "Given the conversation above, who should act next?"
+                " Or should we WAIT for user input? Select one of: {options}",
+            ),
+        ]
+    ).partial(options=str(options), team_members=", ".join(members))
+    return (
+        prompt
+        | llm.bind_functions(functions=[function_def], function_call="route")
+        | JsonOutputFunctionsParser()
+    )
+
+# Define the state for the system
+class AIMSState(TypedDict):
+    messages: List[BaseMessage]
+    next: str
+    quiz: List[dict]
+    quiz_created: bool
+    question_answered: bool
+
+
+# Create the LangGraph chain
+def create_aims_chain(retrieval_chain):
+
+    retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)
+
+    # Create QA Agent
+    qa_agent = create_agent(
+        llm,
+        [retrieve_information_tool],
+        "You are a QA assistant who answers questions about the provided notebook content.",
+    )
+
+    qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")
+
+    # Create Quiz Agent
+    quiz_agent = create_agent(
+        llm,
+        [retrieve_information_tool],
+        "You are a quiz creator that generates quizzes based on the provided notebook content."
+
+        """First, you MUST use the retrieve_information tool to gather relevant and accurate context from the notebook.
+
+        Next, create a 5-question quiz based on the information you have gathered. Include the answers at the end of the quiz.
+
+        Present the quiz to the user in a clear and concise manner."""
+    )
+
+    quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
+
+    # Create Supervisor Agent
+    supervisor_agent = create_team_supervisor(
+        llm,
+        "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
+        ["QAAgent", "QuizAgent"],
+    )
+
+    # Build the LangGraph
+    aims_graph = StateGraph(AIMSState)
+    aims_graph.add_node("QAAgent", qa_node)
+    aims_graph.add_node("QuizAgent", quiz_node)
+    aims_graph.add_node("supervisor", supervisor_agent)
+
+    aims_graph.add_edge("QAAgent", "supervisor")
+    aims_graph.add_edge("QuizAgent", "supervisor")
+    aims_graph.add_conditional_edges(
+        "supervisor",
+        lambda x: "FINISH" if x.get("quiz_created") else ("FINISH" if x.get("question_answered") else x["next"]),
+        {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END, "question_answered": END},
+    )
+
+    aims_graph.set_entry_point("supervisor")
+    return aims_graph.compile()
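One subtlety worth spelling out: the supervisor node emits a dict like `{"next": "QAAgent"}` (the parsed arguments of the forced `route` function call), and the conditional edge then overrides that choice with `FINISH` once a quiz has been created or a question answered. Here is a pure-Python restatement of the routing lambda above, runnable as-is:

```python
def route_from_supervisor(state: dict) -> str:
    # Mirrors the lambda passed to add_conditional_edges in create_aims_chain.
    if state.get("quiz_created"):
        return "FINISH"   # mapped to END in the edge table
    if state.get("question_answered"):
        return "FINISH"   # likewise ends the current turn
    return state["next"]  # "QAAgent", "QuizAgent", "WAIT", or "FINISH"

assert route_from_supervisor({"quiz_created": True, "next": "QuizAgent"}) == "FINISH"
assert route_from_supervisor({"quiz_created": False, "question_answered": False, "next": "QAAgent"}) == "QAAgent"
```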
aims_tutor/retrieval.py
CHANGED

@@ -42,3 +42,10 @@ class RetrievalManager:
         response = retrieval_augmented_qa_chain.invoke({"question": question})
 
         return response["response"].content
+
+    def get_RAG_QA_chain(self):
+        return (
+            {"context": itemgetter("question") | self.retriever, "question": itemgetter("question")}
+            | RunnablePassthrough.assign(context=itemgetter("context"))
+            | {"response": self.prompts.get_rag_qa_prompt() | self.chat_model, "context": itemgetter("context")}
+        )
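The new method returns a dict-in/dict-out LCEL runnable: callers pass `{"question": ...}` and get back `{"response": <model message>, "context": <retrieved documents>}`, which is why both `graph.py` and this module read `response["response"].content`. Below is a minimal offline sketch of that shape; the stubbed retriever, model, and prompt text stand in for `self.retriever`, `self.chat_model`, and `self.prompts.get_rag_qa_prompt()` and are not the repo's implementations.

```python
from operator import itemgetter

from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

fake_retriever = RunnableLambda(lambda q: [f"doc about: {q}"])         # stand-in for self.retriever
fake_chat_model = RunnableLambda(lambda _: AIMessage(content="stub"))  # stand-in for self.chat_model
rag_qa_prompt = ChatPromptTemplate.from_template(                      # stand-in prompt text
    "Answer from this context:\n{context}\n\nQuestion: {question}"
)

chain = (
    {"context": itemgetter("question") | fake_retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_qa_prompt | fake_chat_model, "context": itemgetter("context")}
)

result = chain.invoke({"question": "What does the notebook cover?"})
print(result["response"].content)  # -> "stub"
print(result["context"])           # -> the retrieved documents
```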
requirements.txt
CHANGED

@@ -1,4 +1,5 @@
 langchain==0.1.20
+langgraph==0.0.48
 crewai==0.30.0
 qdrant-client==1.9.1
 python-dotenv==1.0.1