JulsdL committed
Commit 4e8b6ab · Parent(s): ead288d

Refactor quiz functionality and improve agent state management


- Removed unused imports and the generate_quiz tool to streamline the quiz creation process.
- Enhanced the agent_node function to handle the new quiz and question-answering logic, including error handling for missing messages and state updates for quiz creation and question answering.
- Updated the create_team_supervisor and agent_node functions to support the new quiz creation flow and state management.
- Added quiz_created and question_answered flags to AIMSState to track the state of quiz interaction and question answering.
- Modified chainlit_frontend.py to initialize the new state flags and handle frontend messaging for quiz creation and question answering.
- Simplified the conditional edge logic in the LangGraph chain to accommodate the new state flags and improve readability (a condensed routing sketch follows this list).
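For orientation, the sketch below condenses the new supervisor routing from the aims_tutor/graph.py hunks further down. It is illustrative only: the commit's nested ternary is rewritten as an equivalent plain function, and END is replaced by a string stand-in so the snippet runs without langgraph installed.

```python
# Condensed, stand-alone restatement of the new conditional-edge logic.
END = "__end__"  # stands in for langgraph.graph.END

def route(state: dict) -> str:
    # Equivalent to the commit's lambda: finish the turn as soon as a quiz
    # has been created or a question has been answered.
    if state.get("quiz_created") or state.get("question_answered"):
        return "FINISH"
    return state["next"]

edges = {
    "QAAgent": "QAAgent",
    "QuizAgent": "QuizAgent",
    "WAIT": END,
    "FINISH": END,
    "question_answered": END,  # fallback; normally shadowed by the FINISH branch
}

print(edges[route({"next": "QuizAgent", "quiz_created": False, "question_answered": False})])  # QuizAgent
print(edges[route({"next": "supervisor", "quiz_created": True, "question_answered": False})])  # __end__
print(edges[route({"next": "question_answered", "question_answered": True})])                  # __end__
```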

aims_tutor/chainlit_frontend.py CHANGED

@@ -30,6 +30,7 @@ async def start_chat():
     ).send()
 
     file = files[0]  # Get the first file
+
     if file:
         notebook_path = file.path
         doc_manager = DocumentManager(notebook_path)
@@ -40,7 +41,7 @@ async def start_chat():
 
         # Initialize LangGraph chain with the retrieval chain
         retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
-        cl.user_session.set("retrieval_chain", retrieval_chain)  # Store the retrieval chain in the session
+        cl.user_session.set("retrieval_chain", retrieval_chain)
         aims_chain = create_aims_chain(retrieval_chain)
         cl.user_session.set("aims_chain", aims_chain)
 
@@ -55,7 +56,8 @@ async def main(message: cl.Message):
 
     # Create the initial state with the user message
     user_message = message.content
-    state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[])
+    state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[], quiz_created=False, question_answered=False)
+
 
     print(f"Initial state: {state}")
 
@@ -73,5 +75,15 @@
             else:
                 print("Error: No messages found in agent state.")
         else:
+            # Check if the quiz was created and send it to the frontend
+            if state["quiz_created"]:
+                quiz_message = state["messages"][-1].content
+                await cl.Message(content=quiz_message).send()
+            # Check if a question was answered and send the response to the frontend
+            if state["question_answered"]:
+                qa_message = state["messages"][-1].content
+                await cl.Message(content=qa_message).send()
+
             print("Reached end state.")
+
             break
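Since AIMSState is a TypedDict (declared in aims_tutor/graph.py below), the state built in main() above is an ordinary dict at runtime. A minimal stand-alone sketch of the new initialization, with the class re-declared locally so it runs outside the project:

```python
# AIMSState re-declared locally for illustration; mirrors the TypedDict in graph.py.
from typing import List, TypedDict
from langchain_core.messages import BaseMessage, HumanMessage

class AIMSState(TypedDict):
    messages: List[BaseMessage]
    next: str
    quiz: List[dict]
    quiz_created: bool       # new in this commit
    question_answered: bool  # new in this commit

# Same shape as the initial state created in chainlit_frontend.py's main().
state = AIMSState(
    messages=[HumanMessage(content="Explain the attention section of the notebook")],
    next="supervisor",
    quiz=[],
    quiz_created=False,
    question_answered=False,
)
print(state["quiz_created"])  # False until the QuizAgent finishes a quiz
```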
aims_tutor/graph.py CHANGED

@@ -2,14 +2,12 @@ from typing import Annotated, List, TypedDict
 from dotenv import load_dotenv
 from langchain_core.tools import tool
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.messages import AIMessage, BaseMessage
 from langchain.agents import AgentExecutor, create_openai_functions_agent
 from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
-from langchain_core.runnables import RunnablePassthrough
 from langchain_openai import ChatOpenAI
 from langgraph.graph import END, StateGraph
 import functools
-from retrieval import RetrievalManager
 
 # Load environment variables
 load_dotenv()
@@ -34,15 +32,6 @@ def get_retrieve_information_tool(retrieval_chain):
     wrapper_instance = RetrievalChainWrapper(retrieval_chain)
     return tool(wrapper_instance.retrieve_information)
 
-@tool
-def generate_quiz(
-    documents: Annotated[List[str], "List of documents to generate quiz from"],
-    num_questions: Annotated[int, "Number of questions to generate"] = 5
-) -> Annotated[List[dict], "List of quiz questions"]:
-    """Generate a quiz based on the provided documents."""
-    questions = [{"question": f"Question {i+1}", "options": ["Option 1", "Option 2", "Option 3"], "answer": "Option 1"} for i in range(num_questions)]
-    return questions
-
 # Function to create agents
 def create_agent(
     llm: ChatOpenAI,
@@ -65,13 +54,25 @@ def create_agent(
         ]
     )
     agent = create_openai_functions_agent(llm, tools, prompt)
-    executor = AgentExecutor(agent=agent, tools=tools)
+    executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
     return executor
 
 # Function to create agent nodes
 def agent_node(state, agent, name):
     result = agent.invoke(state)
-    return {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
+    if 'messages' not in result:  # Check if messages are present in the agent state
+        raise ValueError(f"No messages found in agent state: {result}")
+    new_state = {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
+    if "next" in result:
+        new_state["next"] = result["next"]
+    if name == "QuizAgent" and "quiz_created" in state and not state["quiz_created"]:
+        new_state["quiz_created"] = True
+        new_state["next"] = "FINISH"  # Finish the conversation after the quiz is created and wait for a new user input
+    if name == "QAAgent":
+        new_state["question_answered"] = True
+        new_state["next"] = "question_answered"
+    return new_state
+
 
 # Function to create the supervisor
 def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
@@ -116,6 +117,8 @@ class AIMSState(TypedDict):
     messages: List[BaseMessage]
     next: str
     quiz: List[dict]
+    quiz_created: bool
+    question_answered: bool
 
 
 # Create the LangGraph chain
@@ -135,8 +138,14 @@ def create_aims_chain(retrieval_chain):
     # Create Quiz Agent
     quiz_agent = create_agent(
         llm,
-        [generate_quiz, retrieve_information_tool],
-        "You are a quiz creator that generates quizzes based on the provided notebook content. Use the retrieval tool to gather context if needed.",
+        [retrieve_information_tool],
+        "You are a quiz creator that generates quizzes based on the provided notebook content. "
+
+        """First, you MUST use the retrieve_information tool to gather relevant and accurate context from the notebook.
+
+        Next, create a 5-question quiz based on the information you have gathered. Include the answers at the end of the quiz.
+
+        Present the quiz to the user in a clear and concise manner."""
     )
 
     quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
@@ -158,8 +167,8 @@ def create_aims_chain(retrieval_chain):
     aims_graph.add_edge("QuizAgent", "supervisor")
     aims_graph.add_conditional_edges(
         "supervisor",
-        lambda x: x["next"],
-        {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
+        lambda x: "FINISH" if x.get("quiz_created") else ("FINISH" if x.get("question_answered") else x["next"]),
+        {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END, "question_answered": END},
     )
 
     aims_graph.set_entry_point("supervisor")
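To make the effect of the new agent_node concrete, here is a stand-alone sketch that exercises the same branching with a stub agent and a stub message class (AIMessage and the real AgentExecutor are swapped out so the snippet runs without any LLM calls):

```python
from dataclasses import dataclass

@dataclass
class StubMessage:           # stands in for langchain's AIMessage
    content: str
    name: str = ""

class StubAgent:             # stands in for the AgentExecutor returned by create_agent
    def invoke(self, state):
        return {"messages": state["messages"], "output": "Q1 ... Q5, answers at the end"}

def agent_node(state, agent, name):
    # Same branching as the hunk above, with AIMessage replaced by StubMessage.
    result = agent.invoke(state)
    if "messages" not in result:
        raise ValueError(f"No messages found in agent state: {result}")
    new_state = {"messages": state["messages"] + [StubMessage(result["output"], name)]}
    if "next" in result:
        new_state["next"] = result["next"]
    if name == "QuizAgent" and "quiz_created" in state and not state["quiz_created"]:
        new_state["quiz_created"] = True
        new_state["next"] = "FINISH"
    if name == "QAAgent":
        new_state["question_answered"] = True
        new_state["next"] = "question_answered"
    return new_state

state = {"messages": [StubMessage("Quiz me on chapter 2")], "next": "QuizAgent",
         "quiz": [], "quiz_created": False, "question_answered": False}
out = agent_node(state, StubAgent(), "QuizAgent")
print(out["quiz_created"], out["next"])   # True FINISH
```

With quiz_created set, the supervisor's conditional edge resolves to FINISH and the graph ends the turn; the "question_answered": END mapping entry acts only as a fallback for the case where next is "question_answered" while neither flag is truthy.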