# Author: JulsdL
# Commit message: Code documentation and dockerizing AI Notebook Tutor
# Commit: c21a510
from dotenv import load_dotenv
from langgraph.graph import END, StateGraph
from states import TutorState
from agents import create_agent, agent_node, create_team_supervisor, get_retrieve_information_tool, llm, flashcard_tool
from prompt_templates import PromptTemplates
import functools
# Load variables from a local .env file (if one is found) into os.environ.
# NOTE(review): this call runs only after the `agents` import above, which
# already pulls in a module-level `llm` object. If that client reads
# credentials from the environment when it is constructed, values that live
# only in .env may not be visible to it yet — confirm agents.py loads .env
# itself, or move this call above the project imports.
load_dotenv()
def _route_from_supervisor(state):
    """Choose the node to visit after the supervisor has run.

    Returns ``"FINISH"`` as soon as any completion flag (``quiz_created``,
    ``question_answered`` or ``flashcards_created``) is truthy in ``state``;
    otherwise returns the supervisor's choice stored under ``state["next"]``.

    Raises:
        KeyError: if no completion flag is set and ``state`` has no "next" key.
    """
    if (
        state.get("quiz_created")
        or state.get("question_answered")
        or state.get("flashcards_created")
    ):
        return "FINISH"
    return state["next"]


# Create the LangGraph chain
def create_tutor_chain(retrieval_chain):
    """
    Create a tutor chain for the notebook tutor system.

    The chain is a LangGraph ``StateGraph`` over ``TutorState`` with four
    nodes: a QA agent, a quiz agent, a flashcards agent and a supervisor.
    Execution starts at the supervisor. Every worker agent hands control back
    to the supervisor, which then either routes to a worker or ends the run
    (see ``_route_from_supervisor``).

    Parameters:
        retrieval_chain (object): The retrieval chain used for information
            retrieval; it is wrapped into a tool via
            ``get_retrieve_information_tool`` and given to every worker.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    # One PromptTemplates instance serves every agent's prompt.
    prompts = PromptTemplates()
    retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)

    # Worker agents as (node name, tools, prompt getter). These names are the
    # single source of truth for the supervisor's member list, the graph node
    # names and the routing table, so they cannot drift apart.
    worker_specs = (
        ("QAAgent", [retrieve_information_tool], prompts.get_qa_agent_prompt),
        ("QuizAgent", [retrieve_information_tool], prompts.get_quiz_agent_prompt),
        (
            "FlashcardsAgent",
            [retrieve_information_tool, flashcard_tool],
            prompts.get_flashcards_agent_prompt,
        ),
    )

    # Build each worker's graph node; insertion order is preserved.
    worker_nodes = {}
    for name, tools, get_prompt in worker_specs:
        agent = create_agent(llm, tools, get_prompt())
        worker_nodes[name] = functools.partial(agent_node, agent=agent, name=name)

    # Supervisor decides which worker (if any) acts next.
    supervisor_agent = create_team_supervisor(
        llm,
        prompts.get_supervisor_agent_prompt(),
        list(worker_nodes),
    )

    # Build the LangGraph.
    tutor_graph = StateGraph(TutorState)
    for name, node in worker_nodes.items():
        tutor_graph.add_node(name, node)
    tutor_graph.add_node("supervisor", supervisor_agent)

    # Every worker reports back to the supervisor.
    for name in worker_nodes:
        tutor_graph.add_edge(name, "supervisor")

    # Router output -> destination node; "FINISH" terminates the run.
    route_map = {name: name for name in worker_nodes}
    route_map["FINISH"] = END
    tutor_graph.add_conditional_edges("supervisor", _route_from_supervisor, route_map)

    tutor_graph.set_entry_point("supervisor")
    return tutor_graph.compile()