JulsdL committed on
Commit
e4f4515
·
1 Parent(s): 85e07b4

Added a test.py with a LangGraph implementation with a supervisor, a QA Agent and a Quiz Agent

Browse files
Files changed (2) hide show
  1. aims_tutor/test.py +169 -0
  2. requirements.txt +1 -0
aims_tutor/test.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List, TypedDict
2
+ from dotenv import load_dotenv
3
+ from langchain_core.tools import tool
4
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
6
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
7
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
8
+ from langchain_core.runnables import RunnablePassthrough
9
+ from langchain_openai import ChatOpenAI
10
+ from langgraph.graph import END, StateGraph
11
+ import functools
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+
16
@tool
def generate_quiz(
    documents: Annotated[List[str], "List of documents to generate quiz from"],
    num_questions: Annotated[int, "Number of questions to generate"] = 5
) -> Annotated[List[dict], "List of quiz questions"]:
    """Generate a quiz based on the provided documents."""
    # Placeholder logic for quiz generation.
    # A real implementation would derive questions from `documents` via NLP.
    quiz_items = []
    for idx in range(num_questions):
        quiz_items.append(
            {
                "question": f"Question {idx+1}",
                "options": ["Option 1", "Option 2", "Option 3"],
                "answer": "Option 1",
            }
        )
    return quiz_items
26
+
27
@tool
def retrieve_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the provided content."""
    # Placeholder: a real implementation would run a RAG pipeline over the content.
    placeholder = {"response": "This is a placeholder response for retrieval information."}
    return placeholder
33
+
34
+
35
# Define a function to create agents
def create_agent(
    llm: ChatOpenAI,
    tools: list,
    system_prompt: str,
    team_members: str = "",
) -> AgentExecutor:
    """Create a function-calling agent and add it to the graph.

    Args:
        llm: Chat model backing the agent.
        tools: Tools the agent may call.
        system_prompt: Role-specific system instructions.
        team_members: Comma-separated team member names substituted into the
            prompt's {team_members} slot (new, optional — defaults to empty).

    Returns:
        An AgentExecutor wrapping the function-calling agent.
    """
    # BUG FIX: the original appended only the first fragment; the following
    # string literals were bare expression statements and were silently
    # discarded, so the autonomy/team instructions never reached the prompt.
    system_prompt += (
        "\nWork autonomously according to your specialty, using the tools available to you."
        " Do not ask for clarification."
        " Your other team members (and other teams) will collaborate with you with their own specialties."
        " You are chosen for a reason! You are one of the following team members: {team_members}."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    ).partial(team_members=team_members)  # pre-fill so invoke() needs no extra key
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor
59
+
60
# Define a function to create agent nodes
def agent_node(state, agent, name):
    """Run `agent` on the current state and append its reply to the message list."""
    output = agent.invoke(state)
    reply = AIMessage(content=output["output"], name=name)
    return {"messages": state["messages"] + [reply]}
64
+
65
# Define a function to create the supervisor
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
    """An LLM-based router."""
    options = ["WAIT", "FINISH"] + members
    # JSON schema the model must fill in: a single "next" field restricted
    # to one of the routing options.
    route_schema = {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [{"enum": options}],
            },
        },
        "required": ["next"],
    }
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": route_schema,
    }
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we WAIT for user input? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    # Force the model to call "route" and parse its arguments into a dict.
    router = prompt | llm.bind_functions(functions=[function_def], function_call="route")
    return router | JsonOutputFunctionsParser()
102
+
103
# Define the state for the system
class AIMSState(TypedDict):
    """Shared graph state passed between the supervisor and agent nodes."""
    # Conversation history; agent_node appends an AIMessage per turn.
    messages: List[BaseMessage]
    # Name of the next node chosen by the supervisor router ("QAAgent",
    # "QuizAgent", "WAIT", or "FINISH").
    next: str
    # Quiz questions slot — declared but not written anywhere in this
    # script; presumably filled by a future quiz flow (TODO confirm).
    quiz: List[dict]
108
+
109
+
110
# Instantiate the language model
# NOTE: runs at import time and requires OPENAI_API_KEY (loaded via dotenv above).
llm = ChatOpenAI(model="gpt-4o")

# Create QA Agent
qa_agent = create_agent(
    llm,
    [retrieve_information],  # Existing QA tool
    "You are a QA assistant who answers questions about the provided notebook content.",
)
qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")

# Create Quiz Agent
quiz_agent = create_agent(
    llm,
    [generate_quiz],
    "You are a quiz creator that generates quizzes based on the provided notebook content.",
)
quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")

# Create Supervisor Agent
supervisor_agent = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
    ["QAAgent", "QuizAgent"],
)

# Build the LangGraph
aims_graph = StateGraph(AIMSState)
aims_graph.add_node("QAAgent", qa_node)
aims_graph.add_node("QuizAgent", quiz_node)
aims_graph.add_node("supervisor", supervisor_agent)

# Both worker agents report back to the supervisor after each turn.
aims_graph.add_edge("QAAgent", "supervisor")
aims_graph.add_edge("QuizAgent", "supervisor")
# The supervisor's "next" field routes to a worker, or ends the run on
# WAIT/FINISH.
aims_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
)

# Every run starts at the supervisor, which picks the first worker.
aims_graph.set_entry_point("supervisor")
chain = aims_graph.compile()
152
+
153
if __name__ == "__main__":

    # Define the function to enter the chain
    def enter_chain(message: str):
        """Wrap a raw user message into the graph's initial state.

        BUG FIX: the original ignored `message` and hard-coded the prompt
        text, so every invocation ran the same request regardless of input.
        """
        results = {
            "messages": [HumanMessage(content=message)],
        }
        return results

    # Piping a plain function into a Runnable coerces it into the chain head.
    aims_chain = enter_chain | chain

    for s in aims_chain.stream(
        "I'd like to take a quiz based on the uploaded notebook.", {"recursion_limit": 15}
    ):
        if "__end__" not in s:
            print(s)
            print("---")
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  langchain==0.1.20
 
2
  crewai==0.30.0
3
  qdrant-client==1.9.1
4
  python-dotenv==1.0.1
 
1
  langchain==0.1.20
2
+ langgraph==0.0.48
3
  crewai==0.30.0
4
  qdrant-client==1.9.1
5
  python-dotenv==1.0.1