File size: 5,951 Bytes
e4f4515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from typing import Annotated, List, TypedDict
from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
import functools

# Load environment variables
load_dotenv()

@tool
def generate_quiz(
    documents: Annotated[List[str], "List of documents to generate quiz from"],
    num_questions: Annotated[int, "Number of questions to generate"] = 5
) -> Annotated[List[dict], "List of quiz questions"]:
    """Generate a quiz based on the provided documents."""
    # NOTE(review): the docstring and Annotated descriptions are presumably
    # surfaced to the model by @tool, so they are treated as runtime text.
    #
    # Placeholder implementation: `documents` is not read yet. Every entry has
    # the same three options and the same answer; only the number in the
    # question text differs.
    quiz = []
    for number in range(1, num_questions + 1):
        quiz.append(
            {
                "question": f"Question {number}",
                "options": ["Option 1", "Option 2", "Option 3"],
                "answer": "Option 1",
            }
        )
    return quiz

@tool
def retrieve_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the provided content."""
    # Placeholder implementation: `query` is not consulted and the same canned
    # payload is returned on every call.
    placeholder = {"response": "This is a placeholder response for retrieval information."}
    return placeholder


# Define a function to create agents
def create_agent(
    llm: ChatOpenAI,
    tools: list,
    system_prompt: str,
    team_members: "List[str] | None" = None,
) -> AgentExecutor:
    """Create a function-calling agent wrapped in an ``AgentExecutor``.

    The caller's ``system_prompt`` is extended with shared instructions about
    working autonomously, then combined with placeholders for the message
    history and the agent scratchpad.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; the same list is given to the executor.
        system_prompt: Role-specific instructions for this agent.
        team_members: Optional names of the agents on the team. When given,
            the prompt also lists them; when omitted, that sentence is left
            out so the prompt needs no ``team_members`` input at invoke time.

    Returns:
        An ``AgentExecutor`` that runs the agent with ``tools``.
    """
    # Parenthesised so the adjacent literals are concatenated into one string;
    # written as separate statements they are evaluated and thrown away.
    system_prompt += (
        "\nWork autonomously according to your specialty, using the tools available to you."
        " Do not ask for clarification."
        " Your other team members (and other teams) will collaborate with you with their own specialties."
        " You are chosen for a reason!"
    )
    if team_members:
        # Only add the template variable when a value is available to fill it.
        system_prompt += " You are one of the following team members: {team_members}."
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    if team_members:
        # Pre-fill the variable so callers invoke with the same inputs as before.
        prompt = prompt.partial(team_members=", ".join(team_members))
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor

# Define a function to create agent nodes
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and return the message history plus its reply.

    Args:
        state: Mapping with a ``"messages"`` entry; passed unchanged to
            ``agent.invoke``.
        agent: Object whose ``invoke`` result exposes an ``"output"`` entry.
        name: Name attached to the produced ``AIMessage``.

    Returns:
        A dict whose ``"messages"`` value is the existing history followed by
        one new ``AIMessage`` carrying the agent's output.
    """
    agent_result = agent.invoke(state)
    history = state["messages"]
    reply = AIMessage(content=agent_result["output"], name=name)
    # Concatenation builds a new list, so the list held in `state` is not mutated.
    return {"messages": history + [reply]}

# Define a function to create the supervisor
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> AgentExecutor:
    """Build an LLM-based router that picks who acts next.

    Args:
        llm: Chat model that makes the routing decision.
        system_prompt: Instructions placed at the start of the router prompt.
        members: Names of the agents that can be selected.

    Returns:
        The pipeline ``prompt | llm-with-forced-"route"-call | JSON parser``.
        The function schema requires a ``"next"`` field limited to ``"WAIT"``,
        ``"FINISH"`` or one of ``members``.
    """
    options = ["WAIT", "FINISH"] + members

    # Schema of the single function the model is forced to call.
    next_field = {
        "title": "Next",
        "anyOf": [
            {"enum": options},
        ],
    }
    route_parameters = {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": next_field},
        "required": ["next"],
    }
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": route_parameters,
    }

    # Closing system message that asks for the decision after the history.
    routing_question = (
        "Given the conversation above, who should act next?"
        " Or should we WAIT for user input? Select one of: {options}"
    )
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", routing_question),
        ]
    )
    prompt = prompt_template.partial(
        options=str(options),
        team_members=", ".join(members),
    )

    router_llm = llm.bind_functions(functions=[function_def], function_call="route")
    return prompt | router_llm | JsonOutputFunctionsParser()

# Define the state for the system
class AIMSState(TypedDict):
    """State dictionary used as the schema of the agent graph."""

    # Conversation history; agent nodes return it with their reply appended.
    messages: List[BaseMessage]
    # Routing key read by the graph's conditional edges.
    next: str
    # Quiz entries, each a plain dict.
    quiz: List[dict]


# Chat model shared by both worker agents and the supervisor
llm = ChatOpenAI(model="gpt-4o")

# QA worker: answers questions through the retrieval tool
qa_agent = create_agent(
    llm=llm,
    tools=[retrieve_information],
    system_prompt="You are a QA assistant who answers questions about the provided notebook content.",
)
qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")

# Quiz worker: builds quizzes through the quiz-generation tool
quiz_agent = create_agent(
    llm=llm,
    tools=[generate_quiz],
    system_prompt="You are a quiz creator that generates quizzes based on the provided notebook content.",
)
quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")

# Supervisor: routes each turn to one of the two workers, or stops
supervisor_agent = create_team_supervisor(
    llm=llm,
    system_prompt="You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent. Given the user request, decide which agent should act next.",
    members=["QAAgent", "QuizAgent"],
)

# Assemble the LangGraph: two worker nodes plus the routing supervisor
aims_graph = StateGraph(AIMSState)

for _node_name, _node in (
    ("QAAgent", qa_node),
    ("QuizAgent", quiz_node),
    ("supervisor", supervisor_agent),
):
    aims_graph.add_node(_node_name, _node)

# Each worker hands control back to the supervisor after it runs
for _worker_name in ("QAAgent", "QuizAgent"):
    aims_graph.add_edge(_worker_name, "supervisor")

# The supervisor's "next" value selects the following node;
# both WAIT and FINISH map to END.
aims_graph.add_conditional_edges(
    "supervisor",
    lambda state: state["next"],
    {
        "QAAgent": "QAAgent",
        "QuizAgent": "QuizAgent",
        "WAIT": END,
        "FINISH": END,
    },
)

# Every run starts at the supervisor
aims_graph.set_entry_point("supervisor")
chain = aims_graph.compile()

if __name__ == "__main__":

    def enter_chain(message: str):
        """Wrap a raw user message in the initial graph state.

        Args:
            message: Text of the user's request.

        Returns:
            A dict whose ``"messages"`` list holds one ``HumanMessage``
            built from ``message``.
        """
        # Use the caller's text; a hard-coded sentence here would silently
        # ignore whatever is streamed into the chain.
        return {"messages": [HumanMessage(content=message)]}

    # Piping the function into the compiled graph converts the input string
    # into graph state before the graph runs.
    aims_chain = enter_chain | chain

    for s in aims_chain.stream(
        "I'd like to take a quiz based on the uploaded notebook.", {"recursion_limit": 15}
    ):
        # Print every streamed step except the one keyed "__end__".
        if "__end__" not in s:
            print(s)
            print("---")