JulsdL committed on
Commit
deeba11
·
1 Parent(s): e4f4515

Integrate LangGraph chain for AIMS quiz functionality and extend retrieval.py with RAG QA chain

Browse files

- Added AIMessage and HumanMessage imports to support message processing in chainlit_frontend.py.
- Implemented AIMSState and create_aims_chain function to initialize LangGraph chain with retrieval capabilities.
- Modified main message handler in chainlit_frontend.py to utilize the LangGraph chain for processing user messages and generating responses.
- Extended retrieval.py with get_RAG_QA_chain method to facilitate retrieval-augmented question answering within the AIMS quiz functionality.

aims_tutor/chainlit_frontend.py CHANGED
@@ -2,6 +2,8 @@ import chainlit as cl
2
  from dotenv import load_dotenv
3
  from document_processing import DocumentManager
4
  from retrieval import RetrievalManager
 
 
5
 
6
  # Load environment variables
7
  load_dotenv()
@@ -36,16 +38,40 @@ async def start_chat():
36
  cl.user_session.set("docs", doc_manager.get_documents())
37
  cl.user_session.set("retrieval_manager", RetrievalManager(doc_manager.get_retriever()))
38
 
 
 
 
 
 
 
39
  @cl.on_message
40
  async def main(message: cl.Message):
41
- # Retrieve the multi-query retriever from session
42
- retrieval_manager = cl.user_session.get("retrieval_manager")
43
- if not retrieval_manager:
 
44
  await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
45
  return
46
 
47
- question = message.content
48
- response = retrieval_manager.notebook_QA(question) # Process the question
 
 
 
 
 
 
 
49
 
50
- msg = cl.Message(content=response)
51
- await msg.send()
 
 
 
 
 
 
 
 
 
 
 
2
  from dotenv import load_dotenv
3
  from document_processing import DocumentManager
4
  from retrieval import RetrievalManager
5
+ from langchain_core.messages import AIMessage, HumanMessage
6
+ from graph import create_aims_chain, AIMSState
7
 
8
  # Load environment variables
9
  load_dotenv()
 
38
  cl.user_session.set("docs", doc_manager.get_documents())
39
  cl.user_session.set("retrieval_manager", RetrievalManager(doc_manager.get_retriever()))
40
 
41
+ # Initialize LangGraph chain with the retrieval chain
42
+ retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
43
+ cl.user_session.set("retrieval_chain", retrieval_chain) # Store the retrieval chain in the session
44
+ aims_chain = create_aims_chain(retrieval_chain)
45
+ cl.user_session.set("aims_chain", aims_chain)
46
+
47
  @cl.on_message
48
  async def main(message: cl.Message):
49
+ # Retrieve the LangGraph chain from the session
50
+ aims_chain = cl.user_session.get("aims_chain")
51
+
52
+ if not aims_chain:
53
  await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
54
  return
55
 
56
+ # Create the initial state with the user message
57
+ user_message = message.content
58
+ state = AIMSState(messages=[HumanMessage(content=user_message)], next="supervisor", quiz=[])
59
+
60
+ print(f"Initial state: {state}")
61
+
62
+ # Process the message through the LangGraph chain
63
+ for s in aims_chain.stream(state, {"recursion_limit": 10}):
64
+ print(f"State after processing: {s}")
65
 
66
+ # Extract messages from the state
67
+ if "__end__" not in s:
68
+ agent_state = next(iter(s.values()))
69
+ if "messages" in agent_state:
70
+ response = agent_state["messages"][-1].content
71
+ print(f"Response: {response}")
72
+ await cl.Message(content=response).send()
73
+ else:
74
+ print("Error: No messages found in agent state.")
75
+ else:
76
+ print("Reached end state.")
77
+ break
aims_tutor/{test.py → graph.py} RENAMED
@@ -9,10 +9,31 @@ from langchain_core.runnables import RunnablePassthrough
9
  from langchain_openai import ChatOpenAI
10
  from langgraph.graph import END, StateGraph
11
  import functools
 
12
 
13
  # Load environment variables
14
  load_dotenv()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @tool
17
  def generate_quiz(
18
  documents: Annotated[List[str], "List of documents to generate quiz from"],
@@ -24,14 +45,6 @@ def generate_quiz(
24
  questions = [{"question": f"Question {i+1}", "options": ["Option 1", "Option 2", "Option 3"], "answer": "Option 1"} for i in range(num_questions)]
25
  return questions
26
 
27
- @tool
28
- def retrieve_information(
29
- query: Annotated[str, "query to ask the retrieve information tool"]
30
- ):
31
- """Use Retrieval Augmented Generation to retrieve information about the provided content."""
32
- return {"response": "This is a placeholder response for retrieval information."}
33
-
34
-
35
  # Define a function to create agents
36
  def create_agent(
37
  llm: ChatOpenAI,
@@ -107,63 +120,48 @@ class AIMSState(TypedDict):
107
  quiz: List[dict]
108
 
109
 
110
- # Instantiate the language model
111
- llm = ChatOpenAI(model="gpt-4o")
112
 
113
- # Create QA Agent
114
- qa_agent = create_agent(
115
- llm,
116
- [retrieve_information], # Existing QA tool
117
- "You are a QA assistant who answers questions about the provided notebook content.",
118
- )
119
- qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")
120
-
121
- # Create Quiz Agent
122
- quiz_agent = create_agent(
123
- llm,
124
- [generate_quiz],
125
- "You are a quiz creator that generates quizzes based on the provided notebook content.",
126
- )
127
- quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
128
-
129
- # Create Supervisor Agent
130
- supervisor_agent = create_team_supervisor(
131
- llm,
132
- "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
133
- ["QAAgent", "QuizAgent"],
134
- )
135
-
136
- # Build the LangGraph
137
- aims_graph = StateGraph(AIMSState)
138
- aims_graph.add_node("QAAgent", qa_node)
139
- aims_graph.add_node("QuizAgent", quiz_node)
140
- aims_graph.add_node("supervisor", supervisor_agent)
141
-
142
- aims_graph.add_edge("QAAgent", "supervisor")
143
- aims_graph.add_edge("QuizAgent", "supervisor")
144
- aims_graph.add_conditional_edges(
145
- "supervisor",
146
- lambda x: x["next"],
147
- {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
148
- )
149
-
150
- aims_graph.set_entry_point("supervisor")
151
- chain = aims_graph.compile()
152
-
153
- if __name__ == "__main__":
154
-
155
- # Define the function to enter the chain
156
- def enter_chain(message: str):
157
- results = {
158
- "messages": [HumanMessage(content="I'd like to take a quiz based on the uploaded notebook.")],
159
- }
160
- return results
161
-
162
- aims_chain = enter_chain | chain
163
-
164
- for s in aims_chain.stream(
165
- "I'd like to take a quiz based on the uploaded notebook.", {"recursion_limit": 15}
166
- ):
167
- if "__end__" not in s:
168
- print(s)
169
- print("---")
 
9
  from langchain_openai import ChatOpenAI
10
  from langgraph.graph import END, StateGraph
11
  import functools
12
+ from retrieval import RetrievalManager
13
 
14
  # Load environment variables
15
  load_dotenv()
16
 
17
+ # Instantiate the language model
18
+ llm = ChatOpenAI(model="gpt-4o")
19
+
20
+ class RetrievalChainWrapper:
21
+ def __init__(self, retrieval_chain):
22
+ self.retrieval_chain = retrieval_chain
23
+
24
+ def retrieve_information(
25
+ self,
26
+ query: Annotated[str, "query to ask the RAG tool"]
27
+ ):
28
+ """Use this tool to retrieve information about the provided notebook."""
29
+ response = self.retrieval_chain.invoke({"question": query})
30
+ return response["response"].content
31
+
32
+ # Create an instance of the wrapper
33
+ def get_retrieve_information_tool(retrieval_chain):
34
+ wrapper_instance = RetrievalChainWrapper(retrieval_chain)
35
+ return tool(wrapper_instance.retrieve_information)
36
+
37
  @tool
38
  def generate_quiz(
39
  documents: Annotated[List[str], "List of documents to generate quiz from"],
 
45
  questions = [{"question": f"Question {i+1}", "options": ["Option 1", "Option 2", "Option 3"], "answer": "Option 1"} for i in range(num_questions)]
46
  return questions
47
 
 
 
 
 
 
 
 
 
48
  # Define a function to create agents
49
  def create_agent(
50
  llm: ChatOpenAI,
 
120
  quiz: List[dict]
121
 
122
 
123
+ # Create the LangGraph chain
124
+ def create_aims_chain(retrieval_chain):
125
 
126
+ retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)
127
+
128
+ # Create QA Agent
129
+ qa_agent = create_agent(
130
+ llm,
131
+ [retrieve_information_tool],
132
+ "You are a QA assistant who answers questions about the provided notebook content.",
133
+ )
134
+
135
+ qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")
136
+
137
+ # Create Quiz Agent
138
+ quiz_agent = create_agent(
139
+ llm,
140
+ [generate_quiz, retrieve_information_tool],
141
+ "You are a quiz creator that generates quizzes based on the provided notebook content.",
142
+ )
143
+ quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
144
+
145
+ # Create Supervisor Agent
146
+ supervisor_agent = create_team_supervisor(
147
+ llm,
148
+ "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
149
+ ["QAAgent", "QuizAgent"],
150
+ )
151
+
152
+ # Build the LangGraph
153
+ aims_graph = StateGraph(AIMSState)
154
+ aims_graph.add_node("QAAgent", qa_node)
155
+ aims_graph.add_node("QuizAgent", quiz_node)
156
+ aims_graph.add_node("supervisor", supervisor_agent)
157
+
158
+ aims_graph.add_edge("QAAgent", "supervisor")
159
+ aims_graph.add_edge("QuizAgent", "supervisor")
160
+ aims_graph.add_conditional_edges(
161
+ "supervisor",
162
+ lambda x: x["next"],
163
+ {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
164
+ )
165
+
166
+ aims_graph.set_entry_point("supervisor")
167
+ return aims_graph.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aims_tutor/retrieval.py CHANGED
@@ -42,3 +42,10 @@ class RetrievalManager:
42
  response = retrieval_augmented_qa_chain.invoke({"question": question})
43
 
44
  return response["response"].content
 
 
 
 
 
 
 
 
42
  response = retrieval_augmented_qa_chain.invoke({"question": question})
43
 
44
  return response["response"].content
45
+
46
+ def get_RAG_QA_chain(self):
47
+ return (
48
+ {"context": itemgetter("question") | self.retriever, "question": itemgetter("question")}
49
+ | RunnablePassthrough.assign(context=itemgetter("context"))
50
+ | {"response": self.prompts.get_rag_qa_prompt() | self.chat_model, "context": itemgetter("context")}
51
+ )