File size: 2,825 Bytes
e4f4515
 
48d9af7
0f64bae
 
e4f4515
 
 
 
 
deeba11
48d9af7
c21a510
 
 
 
 
 
 
 
 
 
 
deeba11
 
 
 
 
 
0f64bae
deeba11
 
 
 
 
 
4e8b6ab
0f64bae
deeba11
 
 
48d9af7
 
 
 
0f64bae
48d9af7
 
 
deeba11
 
 
0f64bae
48d9af7
deeba11
 
 
48d9af7
 
 
 
 
 
 
 
 
 
deeba11
48d9af7
 
 
 
 
deeba11
 
48d9af7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from dotenv import load_dotenv
from langgraph.graph import END, StateGraph
from states import TutorState
from agents import create_agent, agent_node, create_team_supervisor, get_retrieve_information_tool, llm, flashcard_tool
from prompt_templates import PromptTemplates
import functools

# Load environment variables (python-dotenv: from a `.env` file, if present).
# NOTE(review): `llm` and the other agent helpers are imported from `agents`
# above, i.e. before this call runs. If `agents` reads environment variables
# (e.g. API keys) at import time, those values will NOT come from `.env` —
# confirm, and if so move load_dotenv() ahead of the project imports.
load_dotenv()

# Graph node names. They are used when registering nodes, when telling the
# supervisor which members it may route to, and in the conditional-edge path
# map, so they are defined once here to keep those places in sync.
_QA_AGENT = "QAAgent"
_QUIZ_AGENT = "QuizAgent"
_FLASHCARDS_AGENT = "FlashcardsAgent"
_SUPERVISOR = "supervisor"
_FINISH = "FINISH"

# State keys that end the run as soon as any of them is truthy after the
# supervisor step. NOTE(review): presumably set by the worker agents/tools;
# TutorState and agent_node are not visible here — confirm.
_COMPLETION_FLAGS = ("quiz_created", "question_answered", "flashcards_created")


def _route_from_supervisor(state):
    """Return the name of the node to run after the supervisor.

    Returns ``"FINISH"`` if any key in ``_COMPLETION_FLAGS`` is truthy in
    *state*; otherwise returns ``state["next"]``. That key is indexed
    directly, so a missing ``"next"`` raises rather than defaulting.
    """
    if any(state.get(flag) for flag in _COMPLETION_FLAGS):
        return _FINISH
    return state["next"]


# Create the LangGraph chain
def create_tutor_chain(retrieval_chain):
    """
    Build and compile the tutor graph for the notebook tutor system.

    The graph has three worker nodes (QAAgent, QuizAgent, FlashcardsAgent)
    and a supervisor node. Execution starts at the supervisor. Every worker
    hands control back to the supervisor. The supervisor then either ends the
    run (when a completion flag is set, or it routes to "FINISH") or
    dispatches to the worker named in the state's ``"next"`` value.

    Parameters:
        retrieval_chain (object): Passed to ``get_retrieve_information_tool``
            to build the retrieval tool shared by all worker agents.

    Returns:
        The compiled graph returned by ``StateGraph.compile()`` (not the
        uncompiled ``StateGraph`` builder).
    """
    retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)
    # One PromptTemplates instance serves all four prompt lookups.
    prompts = PromptTemplates()

    # Worker agents, in registration/routing order. All share the retrieval
    # tool; the flashcards agent additionally gets flashcard_tool.
    workers = {
        _QA_AGENT: create_agent(
            llm,
            [retrieve_information_tool],
            prompts.get_qa_agent_prompt(),
        ),
        _QUIZ_AGENT: create_agent(
            llm,
            [retrieve_information_tool],
            prompts.get_quiz_agent_prompt(),
        ),
        _FLASHCARDS_AGENT: create_agent(
            llm,
            [retrieve_information_tool, flashcard_tool],
            prompts.get_flashcards_agent_prompt(),
        ),
    }

    # The supervisor chooses among exactly the worker names registered above.
    supervisor_agent = create_team_supervisor(
        llm,
        prompts.get_supervisor_agent_prompt(),
        list(workers),
    )

    tutor_graph = StateGraph(TutorState)
    # Register all nodes first, then wire the edges that reference them.
    for name, agent in workers.items():
        tutor_graph.add_node(
            name, functools.partial(agent_node, agent=agent, name=name)
        )
    tutor_graph.add_node(_SUPERVISOR, supervisor_agent)

    # Each worker reports back to the supervisor.
    for name in workers:
        tutor_graph.add_edge(name, _SUPERVISOR)

    # The supervisor either ends the run or dispatches to a worker by name.
    path_map = {name: name for name in workers}
    path_map[_FINISH] = END
    tutor_graph.add_conditional_edges(_SUPERVISOR, _route_from_supervisor, path_map)

    tutor_graph.set_entry_point(_SUPERVISOR)
    return tutor_graph.compile()