Spaces:
Sleeping
Sleeping
Refactor quiz functionality and improve agent state management
Browse files- Removed unused imports and generate_quiz tool to streamline quiz creation process.
- Enhanced agent_node function to handle new quiz and question answering logic, including error handling for missing messages and state updates for quiz creation and question answering.
- Updated create_team_supervisor and agent_node functions to support new quiz creation flow and state management.
- Added quiz_created and question_answered flags to AIMSState to track the state of quiz interaction and question answering.
- Modified chainlit_frontend.py to initialize new state flags and handle frontend messaging for quiz creation and question answering.
- Simplified conditional edge logic in LangGraph chain to accommodate new state flags and improve readability.
- aims_tutor/chainlit_frontend.py +14 -2
- aims_tutor/graph.py +27 -18
aims_tutor/chainlit_frontend.py
CHANGED
@@ -30,6 +30,7 @@ async def start_chat():
|
|
30 |
).send()
|
31 |
|
32 |
file = files[0] # Get the first file
|
|
|
33 |
if file:
|
34 |
notebook_path = file.path
|
35 |
doc_manager = DocumentManager(notebook_path)
|
@@ -40,7 +41,7 @@ async def start_chat():
|
|
40 |
|
41 |
# Initialize LangGraph chain with the retrieval chain
|
42 |
retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
|
43 |
-
cl.user_session.set("retrieval_chain", retrieval_chain)
|
44 |
aims_chain = create_aims_chain(retrieval_chain)
|
45 |
cl.user_session.set("aims_chain", aims_chain)
|
46 |
|
@@ -55,7 +56,8 @@ async def main(message: cl.Message):
|
|
55 |
|
56 |
# Create the initial state with the user message
|
57 |
user_message = message.content
|
58 |
-
state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[])
|
|
|
59 |
|
60 |
print(f"Initial state: {state}")
|
61 |
|
@@ -73,5 +75,15 @@ async def main(message: cl.Message):
|
|
73 |
else:
|
74 |
print("Error: No messages found in agent state.")
|
75 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
print("Reached end state.")
|
|
|
77 |
break
|
|
|
30 |
).send()
|
31 |
|
32 |
file = files[0] # Get the first file
|
33 |
+
|
34 |
if file:
|
35 |
notebook_path = file.path
|
36 |
doc_manager = DocumentManager(notebook_path)
|
|
|
41 |
|
42 |
# Initialize LangGraph chain with the retrieval chain
|
43 |
retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
|
44 |
+
cl.user_session.set("retrieval_chain", retrieval_chain)
|
45 |
aims_chain = create_aims_chain(retrieval_chain)
|
46 |
cl.user_session.set("aims_chain", aims_chain)
|
47 |
|
|
|
56 |
|
57 |
# Create the initial state with the user message
|
58 |
user_message = message.content
|
59 |
+
state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[], quiz_created=False, question_answered=False)
|
60 |
+
|
61 |
|
62 |
print(f"Initial state: {state}")
|
63 |
|
|
|
75 |
else:
|
76 |
print("Error: No messages found in agent state.")
|
77 |
else:
|
78 |
+
# Check if the quiz was created and send it to the frontend
|
79 |
+
if state["quiz_created"]:
|
80 |
+
quiz_message = state["messages"][-1].content
|
81 |
+
await cl.Message(content=quiz_message).send()
|
82 |
+
# Check if a question was answered and send the response to the frontend
|
83 |
+
if state["question_answered"]:
|
84 |
+
qa_message = state["messages"][-1].content
|
85 |
+
await cl.Message(content=qa_message).send()
|
86 |
+
|
87 |
print("Reached end state.")
|
88 |
+
|
89 |
break
|
aims_tutor/graph.py
CHANGED
@@ -2,14 +2,12 @@ from typing import Annotated, List, TypedDict
|
|
2 |
from dotenv import load_dotenv
|
3 |
from langchain_core.tools import tool
|
4 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
5 |
-
from langchain_core.messages import AIMessage, BaseMessage
|
6 |
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
7 |
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
8 |
-
from langchain_core.runnables import RunnablePassthrough
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
from langgraph.graph import END, StateGraph
|
11 |
import functools
|
12 |
-
from retrieval import RetrievalManager
|
13 |
|
14 |
# Load environment variables
|
15 |
load_dotenv()
|
@@ -34,15 +32,6 @@ def get_retrieve_information_tool(retrieval_chain):
|
|
34 |
wrapper_instance = RetrievalChainWrapper(retrieval_chain)
|
35 |
return tool(wrapper_instance.retrieve_information)
|
36 |
|
37 |
-
@tool
|
38 |
-
def generate_quiz(
|
39 |
-
documents: Annotated[List[str], "List of documents to generate quiz from"],
|
40 |
-
num_questions: Annotated[int, "Number of questions to generate"] = 5
|
41 |
-
) -> Annotated[List[dict], "List of quiz questions"]:
|
42 |
-
"""Generate a quiz based on the provided documents."""
|
43 |
-
questions = [{"question": f"Question {i+1}", "options": ["Option 1", "Option 2", "Option 3"], "answer": "Option 1"} for i in range(num_questions)]
|
44 |
-
return questions
|
45 |
-
|
46 |
# Function to create agents
|
47 |
def create_agent(
|
48 |
llm: ChatOpenAI,
|
@@ -65,13 +54,25 @@ def create_agent(
|
|
65 |
]
|
66 |
)
|
67 |
agent = create_openai_functions_agent(llm, tools, prompt)
|
68 |
-
executor = AgentExecutor(agent=agent, tools=tools)
|
69 |
return executor
|
70 |
|
71 |
# Function to create agent nodes
|
72 |
def agent_node(state, agent, name):
|
73 |
result = agent.invoke(state)
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# Function to create the supervisor
|
77 |
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
|
@@ -116,6 +117,8 @@ class AIMSState(TypedDict):
|
|
116 |
messages: List[BaseMessage]
|
117 |
next: str
|
118 |
quiz: List[dict]
|
|
|
|
|
119 |
|
120 |
|
121 |
# Create the LangGraph chain
|
@@ -135,8 +138,14 @@ def create_aims_chain(retrieval_chain):
|
|
135 |
# Create Quiz Agent
|
136 |
quiz_agent = create_agent(
|
137 |
llm,
|
138 |
-
[
|
139 |
-
"You are a quiz creator that generates quizzes based on the provided notebook content.
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
)
|
141 |
|
142 |
quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
|
@@ -158,8 +167,8 @@ def create_aims_chain(retrieval_chain):
|
|
158 |
aims_graph.add_edge("QuizAgent", "supervisor")
|
159 |
aims_graph.add_conditional_edges(
|
160 |
"supervisor",
|
161 |
-
lambda x: x["next"],
|
162 |
-
{"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
|
163 |
)
|
164 |
|
165 |
aims_graph.set_entry_point("supervisor")
|
|
|
2 |
from dotenv import load_dotenv
|
3 |
from langchain_core.tools import tool
|
4 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
5 |
+
from langchain_core.messages import AIMessage, BaseMessage
|
6 |
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
7 |
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
|
8 |
from langchain_openai import ChatOpenAI
|
9 |
from langgraph.graph import END, StateGraph
|
10 |
import functools
|
|
|
11 |
|
12 |
# Load environment variables
|
13 |
load_dotenv()
|
|
|
32 |
wrapper_instance = RetrievalChainWrapper(retrieval_chain)
|
33 |
return tool(wrapper_instance.retrieve_information)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
# Function to create agents
|
36 |
def create_agent(
|
37 |
llm: ChatOpenAI,
|
|
|
54 |
]
|
55 |
)
|
56 |
agent = create_openai_functions_agent(llm, tools, prompt)
|
57 |
+
executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
|
58 |
return executor
|
59 |
|
60 |
# Function to create agent nodes
|
61 |
def agent_node(state, agent, name):
|
62 |
result = agent.invoke(state)
|
63 |
+
if 'messages' not in result: # Check if messages are present in the agent state
|
64 |
+
raise ValueError(f"No messages found in agent state: {result}")
|
65 |
+
new_state = {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
|
66 |
+
if "next" in result:
|
67 |
+
new_state["next"] = result["next"]
|
68 |
+
if name == "QuizAgent" and "quiz_created" in state and not state["quiz_created"]:
|
69 |
+
new_state["quiz_created"] = True
|
70 |
+
new_state["next"] = "FINISH" # Finish the conversation after the quiz is created and wait for a new user input
|
71 |
+
if name == "QAAgent":
|
72 |
+
new_state["question_answered"] = True
|
73 |
+
new_state["next"] = "question_answered"
|
74 |
+
return new_state
|
75 |
+
|
76 |
|
77 |
# Function to create the supervisor
|
78 |
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
|
|
|
117 |
messages: List[BaseMessage]
|
118 |
next: str
|
119 |
quiz: List[dict]
|
120 |
+
quiz_created: bool
|
121 |
+
question_answered: bool
|
122 |
|
123 |
|
124 |
# Create the LangGraph chain
|
|
|
138 |
# Create Quiz Agent
|
139 |
quiz_agent = create_agent(
|
140 |
llm,
|
141 |
+
[retrieve_information_tool],
|
142 |
+
"You are a quiz creator that generates quizzes based on the provided notebook content."
|
143 |
+
|
144 |
+
"""First, You MUST Use the retrieval_inforation_tool to gather context from the notebook to gather relevant and accurate information.
|
145 |
+
|
146 |
+
Next, create a 5-question quiz based on the information you have gathered. Include the answers at the end of the quiz.
|
147 |
+
|
148 |
+
Present the quiz to the user in a clear and concise manner."""
|
149 |
)
|
150 |
|
151 |
quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
|
|
|
167 |
aims_graph.add_edge("QuizAgent", "supervisor")
|
168 |
aims_graph.add_conditional_edges(
|
169 |
"supervisor",
|
170 |
+
lambda x: "FINISH" if x.get("quiz_created") else ("FINISH" if x.get("question_answered") else x["next"]),
|
171 |
+
{"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END, "question_answered": END},
|
172 |
)
|
173 |
|
174 |
aims_graph.set_entry_point("supervisor")
|