import functools

from dotenv import load_dotenv
from langgraph.graph import END, StateGraph

from states import TutorState
from agents import (
    create_agent,
    agent_node,
    create_team_supervisor,
    get_retrieve_information_tool,
    llm,
    flashcard_tool,
)
from prompt_templates import PromptTemplates

# Load environment variables
load_dotenv()


# Create the LangGraph chain
def create_tutor_chain(retrieval_chain):
    """
    Assemble and compile the supervisor-driven tutor graph.

    Three worker nodes ("QAAgent", "QuizAgent", "FlashcardsAgent") are
    registered alongside a "supervisor" node, which is the graph's entry
    point. Every worker has an edge back to the supervisor. After the
    supervisor runs, the graph ends if any of the state keys
    "quiz_created", "question_answered" or "flashcards_created" is truthy;
    otherwise it follows the state's "next" value, where "FINISH" also
    ends the graph.

    Parameters:
        retrieval_chain (object): Passed to get_retrieve_information_tool
            to build the retrieval tool given to every worker agent.

    Returns:
        The result of StateGraph.compile() for the assembled tutor graph.
    """
    supervisor_name = "supervisor"
    retrieval_tool = get_retrieve_information_tool(retrieval_chain)

    # Worker nodes keyed by graph node name; insertion order is the
    # registration order used for nodes, edges and the supervisor roster.
    worker_nodes = {}

    # QA worker: retrieval tool only.
    qa_worker = create_agent(
        llm,
        [retrieval_tool],
        PromptTemplates().get_qa_agent_prompt(),
    )
    worker_nodes["QAAgent"] = functools.partial(
        agent_node, agent=qa_worker, name="QAAgent"
    )

    # Quiz worker: retrieval tool only.
    quiz_worker = create_agent(
        llm,
        [retrieval_tool],
        PromptTemplates().get_quiz_agent_prompt(),
    )
    worker_nodes["QuizAgent"] = functools.partial(
        agent_node, agent=quiz_worker, name="QuizAgent"
    )

    # Flashcards worker: retrieval tool plus the flashcard tool.
    flashcards_worker = create_agent(
        llm,
        [retrieval_tool, flashcard_tool],
        PromptTemplates().get_flashcards_agent_prompt(),
    )
    worker_nodes["FlashcardsAgent"] = functools.partial(
        agent_node, agent=flashcards_worker, name="FlashcardsAgent"
    )

    # Supervisor that picks among the workers registered above.
    supervisor = create_team_supervisor(
        llm,
        PromptTemplates().get_supervisor_agent_prompt(),
        list(worker_nodes),
    )

    def pick_next_step(state):
        """Return "FINISH" once any completion flag is set, else state["next"]."""
        completion_flags = (
            "quiz_created",
            "question_answered",
            "flashcards_created",
        )
        if any(state.get(flag) for flag in completion_flags):
            return "FINISH"
        return state["next"]

    # Build the LangGraph
    graph = StateGraph(TutorState)
    for node_name, node_fn in worker_nodes.items():
        graph.add_node(node_name, node_fn)
    graph.add_node(supervisor_name, supervisor)

    # Every worker hands control back to the supervisor when it is done.
    for node_name in worker_nodes:
        graph.add_edge(node_name, supervisor_name)

    # Each worker name routes to itself; "FINISH" terminates the graph.
    destinations = {node_name: node_name for node_name in worker_nodes}
    destinations["FINISH"] = END
    graph.add_conditional_edges(supervisor_name, pick_next_step, destinations)

    graph.set_entry_point(supervisor_name)
    return graph.compile()