Spaces:
Sleeping
Sleeping
Added a test.py with a LangGraph implementation with a supervisor, a QA Agent and a Quiz Agent
Browse files- aims_tutor/test.py +169 -0
- requirements.txt +1 -0
aims_tutor/test.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, List, TypedDict
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
from langchain_core.tools import tool
|
4 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
5 |
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
6 |
+
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
7 |
+
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
8 |
+
from langchain_core.runnables import RunnablePassthrough
|
9 |
+
from langchain_openai import ChatOpenAI
|
10 |
+
from langgraph.graph import END, StateGraph
|
11 |
+
import functools
|
12 |
+
|
13 |
+
# Load environment variables
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
@tool
|
17 |
+
def generate_quiz(
|
18 |
+
documents: Annotated[List[str], "List of documents to generate quiz from"],
|
19 |
+
num_questions: Annotated[int, "Number of questions to generate"] = 5
|
20 |
+
) -> Annotated[List[dict], "List of quiz questions"]:
|
21 |
+
"""Generate a quiz based on the provided documents."""
|
22 |
+
# Placeholder logic for quiz generation
|
23 |
+
# In a real scenario, you'd use NLP techniques to generate questions
|
24 |
+
questions = [{"question": f"Question {i+1}", "options": ["Option 1", "Option 2", "Option 3"], "answer": "Option 1"} for i in range(num_questions)]
|
25 |
+
return questions
|
26 |
+
|
27 |
+
@tool
|
28 |
+
def retrieve_information(
|
29 |
+
query: Annotated[str, "query to ask the retrieve information tool"]
|
30 |
+
):
|
31 |
+
"""Use Retrieval Augmented Generation to retrieve information about the provided content."""
|
32 |
+
return {"response": "This is a placeholder response for retrieval information."}
|
33 |
+
|
34 |
+
|
35 |
+
# Define a function to create agents
|
36 |
+
def create_agent(
|
37 |
+
llm: ChatOpenAI,
|
38 |
+
tools: list,
|
39 |
+
system_prompt: str,
|
40 |
+
) -> AgentExecutor:
|
41 |
+
"""Create a function-calling agent and add it to the graph."""
|
42 |
+
system_prompt += "\nWork autonomously according to your specialty, using the tools available to you."
|
43 |
+
" Do not ask for clarification."
|
44 |
+
" Your other team members (and other teams) will collaborate with you with their own specialties."
|
45 |
+
" You are chosen for a reason! You are one of the following team members: {team_members}."
|
46 |
+
prompt = ChatPromptTemplate.from_messages(
|
47 |
+
[
|
48 |
+
(
|
49 |
+
"system",
|
50 |
+
system_prompt,
|
51 |
+
),
|
52 |
+
MessagesPlaceholder(variable_name="messages"),
|
53 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
54 |
+
]
|
55 |
+
)
|
56 |
+
agent = create_openai_functions_agent(llm, tools, prompt)
|
57 |
+
executor = AgentExecutor(agent=agent, tools=tools)
|
58 |
+
return executor
|
59 |
+
|
60 |
+
# Define a function to create agent nodes
|
61 |
+
def agent_node(state, agent, name):
|
62 |
+
result = agent.invoke(state)
|
63 |
+
return {"messages": state["messages"] + [AIMessage(content=result["output"], name=name)]}
|
64 |
+
|
65 |
+
# Define a function to create the supervisor
|
66 |
+
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
|
67 |
+
"""An LLM-based router."""
|
68 |
+
options = ["WAIT", "FINISH"] + members
|
69 |
+
function_def = {
|
70 |
+
"name": "route",
|
71 |
+
"description": "Select the next role.",
|
72 |
+
"parameters": {
|
73 |
+
"title": "routeSchema",
|
74 |
+
"type": "object",
|
75 |
+
"properties": {
|
76 |
+
"next": {
|
77 |
+
"title": "Next",
|
78 |
+
"anyOf": [
|
79 |
+
{"enum": options},
|
80 |
+
],
|
81 |
+
},
|
82 |
+
},
|
83 |
+
"required": ["next"],
|
84 |
+
},
|
85 |
+
}
|
86 |
+
prompt = ChatPromptTemplate.from_messages(
|
87 |
+
[
|
88 |
+
("system", system_prompt),
|
89 |
+
MessagesPlaceholder(variable_name="messages"),
|
90 |
+
(
|
91 |
+
"system",
|
92 |
+
"Given the conversation above, who should act next?"
|
93 |
+
" Or should we WAIT for user input? Select one of: {options}",
|
94 |
+
),
|
95 |
+
]
|
96 |
+
).partial(options=str(options), team_members=", ".join(members))
|
97 |
+
return (
|
98 |
+
prompt
|
99 |
+
| llm.bind_functions(functions=[function_def], function_call="route")
|
100 |
+
| JsonOutputFunctionsParser()
|
101 |
+
)
|
102 |
+
|
103 |
+
# Define the state for the system
|
104 |
+
class AIMSState(TypedDict):
|
105 |
+
messages: List[BaseMessage]
|
106 |
+
next: str
|
107 |
+
quiz: List[dict]
|
108 |
+
|
109 |
+
|
110 |
+
# Instantiate the language model
|
111 |
+
llm = ChatOpenAI(model="gpt-4o")
|
112 |
+
|
113 |
+
# Create QA Agent
|
114 |
+
qa_agent = create_agent(
|
115 |
+
llm,
|
116 |
+
[retrieve_information], # Existing QA tool
|
117 |
+
"You are a QA assistant who answers questions about the provided notebook content.",
|
118 |
+
)
|
119 |
+
qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")
|
120 |
+
|
121 |
+
# Create Quiz Agent
|
122 |
+
quiz_agent = create_agent(
|
123 |
+
llm,
|
124 |
+
[generate_quiz],
|
125 |
+
"You are a quiz creator that generates quizzes based on the provided notebook content.",
|
126 |
+
)
|
127 |
+
quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")
|
128 |
+
|
129 |
+
# Create Supervisor Agent
|
130 |
+
supervisor_agent = create_team_supervisor(
|
131 |
+
llm,
|
132 |
+
"You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
|
133 |
+
["QAAgent", "QuizAgent"],
|
134 |
+
)
|
135 |
+
|
136 |
+
# Build the LangGraph
|
137 |
+
aims_graph = StateGraph(AIMSState)
|
138 |
+
aims_graph.add_node("QAAgent", qa_node)
|
139 |
+
aims_graph.add_node("QuizAgent", quiz_node)
|
140 |
+
aims_graph.add_node("supervisor", supervisor_agent)
|
141 |
+
|
142 |
+
aims_graph.add_edge("QAAgent", "supervisor")
|
143 |
+
aims_graph.add_edge("QuizAgent", "supervisor")
|
144 |
+
aims_graph.add_conditional_edges(
|
145 |
+
"supervisor",
|
146 |
+
lambda x: x["next"],
|
147 |
+
{"QAAgent": "QAAgent", "QuizAgent": "QuizAgent", "WAIT": END, "FINISH": END},
|
148 |
+
)
|
149 |
+
|
150 |
+
aims_graph.set_entry_point("supervisor")
|
151 |
+
chain = aims_graph.compile()
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
|
155 |
+
# Define the function to enter the chain
|
156 |
+
def enter_chain(message: str):
|
157 |
+
results = {
|
158 |
+
"messages": [HumanMessage(content="I'd like to take a quiz based on the uploaded notebook.")],
|
159 |
+
}
|
160 |
+
return results
|
161 |
+
|
162 |
+
aims_chain = enter_chain | chain
|
163 |
+
|
164 |
+
for s in aims_chain.stream(
|
165 |
+
"I'd like to take a quiz based on the uploaded notebook.", {"recursion_limit": 15}
|
166 |
+
):
|
167 |
+
if "__end__" not in s:
|
168 |
+
print(s)
|
169 |
+
print("---")
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
langchain==0.1.20
|
|
|
2 |
crewai==0.30.0
|
3 |
qdrant-client==1.9.1
|
4 |
python-dotenv==1.0.1
|
|
|
1 |
langchain==0.1.20
|
2 |
+
langgraph==0.0.48
|
3 |
crewai==0.30.0
|
4 |
qdrant-client==1.9.1
|
5 |
python-dotenv==1.0.1
|