Spaces:
Sleeping
Sleeping
File size: 7,856 Bytes
48d9af7 e4f4515 48d9af7 e4f4515 48d9af7 e4f4515 deeba11 48d9af7 ead288d e4f4515 4e8b6ab e4f4515 ead288d e4f4515 48d9af7 4e8b6ab 48d9af7 4e8b6ab 48d9af7 4e8b6ab 48d9af7 4e8b6ab e4f4515 48d9af7 ead288d e4f4515 deeba11 48d9af7 e4f4515 deeba11 4e8b6ab 48d9af7 4e8b6ab deeba11 ead288d deeba11 48d9af7 deeba11 48d9af7 deeba11 48d9af7 deeba11 48d9af7 deeba11 48d9af7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import functools
from typing import Annotated

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

from states import TutorState
from tools import create_flashcards_tool
# Load environment variables from a local .env file (python-dotenv) before the
# chat model below is constructed; presumably this supplies the OpenAI
# credentials ChatOpenAI needs — confirm against the deployment's .env.
load_dotenv()
# Shared chat model used by every agent and by the supervisor that
# create_tutor_chain builds.
llm = ChatOpenAI(model="gpt-4o")
class RetrievalChainWrapper:
    """Hold a retrieval chain so one bound method can be exposed as a tool.

    ``retrieve_information`` is passed to LangChain's ``tool`` by
    ``get_retrieve_information_tool``; its docstring and the ``Annotated``
    metadata on ``query`` are therefore runtime data and must not change.
    """

    def __init__(self, retrieval_chain):
        # Expected to provide ``invoke({"question": ...})`` returning a mapping
        # whose "response" item has a ``.content`` attribute (see below).
        self.retrieval_chain = retrieval_chain

    def retrieve_information(
        self,
        query: Annotated[str, "query to ask the RAG tool"]
    ):
        """Use this tool to retrieve information about the provided notebook."""
        # Keep the docstring above verbatim: it is the tool description.
        chain_output = self.retrieval_chain.invoke({"question": query})
        answer_message = chain_output["response"]
        return answer_message.content
# Expose a retrieval chain to the agents as a LangChain tool.
def get_retrieve_information_tool(retrieval_chain):
    """Return a tool that answers notebook questions through ``retrieval_chain``.

    The tool wraps the bound ``RetrievalChainWrapper.retrieve_information``
    method, so the tool's name and description presumably come from that
    method's name and docstring (LangChain ``tool`` behaviour).
    """
    bound_retriever = RetrievalChainWrapper(retrieval_chain).retrieve_information
    return tool(bound_retriever)
# Alias the flashcard tool imported from the local ``tools`` module under the
# name used when building the flashcards agent; no new object is created here.
flashcard_tool = create_flashcards_tool
# Function to create agents
def create_agent(
    llm: ChatOpenAI,
    tools: list,
    system_prompt: str,
    team_members: "list[str] | None" = None,
) -> AgentExecutor:
    """Create a function-calling agent wrapped in an ``AgentExecutor``.

    A shared block of collaboration instructions is appended to
    ``system_prompt``.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; given to both the agent and the
            executor.
        system_prompt: Role-specific instructions. It is used as a prompt
            template, so any literal ``{`` / ``}`` must be doubled.
        team_members: Optional names of the agent's teammates. When given,
            the prompt also tells the agent which team it belongs to;
            otherwise that sentence is omitted so the prompt has no unfilled
            ``{team_members}`` variable.

    Returns:
        An ``AgentExecutor`` whose input must provide ``messages``; parsing
        errors from the model are handled by the executor rather than raised.
    """
    # Parenthesised so the adjacent literals concatenate into one string.
    # Written as separate bare statements, every literal after the first is a
    # no-op expression and silently never reaches the prompt.
    instructions = (
        "\nWork autonomously according to your specialty, using the tools available to you."
        " Do not ask for clarification."
        " Your other team members (and other teams) will collaborate with you with their own specialties."
    )
    if team_members:
        instructions += (
            " You are chosen for a reason!"
            " You are one of the following team members: {team_members}."
        )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt + instructions),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    if team_members:
        # Fill the placeholder up front (as create_team_supervisor does) so
        # callers never have to pass team_members on each invocation.
        prompt = prompt.partial(team_members=", ".join(team_members))
    agent = create_openai_functions_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
# Graph node wrapper: run one agent and fold its answer into the state.
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and return the updated graph state.

    Args:
        state: Current graph state; must contain ``messages``.
        agent: Runnable (e.g. the executor built by ``create_agent``) whose
            result mapping carries ``messages`` and ``output``.
        name: Node name; used as the reply message's ``name`` and to pick
            which completion flag to set.

    Returns:
        A new dict holding the message history plus the agent's reply, the
        agent-specific completion flag, ``flashcard_filename`` for the
        flashcards agent, and ``next`` set to ``"FINISH"``.

    Raises:
        ValueError: if the agent's result has no ``messages`` key.
    """
    outcome = agent.invoke(state)
    if "messages" not in outcome:
        raise ValueError(f"No messages found in agent state: {outcome}")

    # NOTE(review): this returns the full history plus the reply. If
    # states.TutorState declares an additive reducer for "messages", the
    # history would be duplicated — confirm against TutorState.
    history = state["messages"]
    answer = outcome["output"]
    update = {"messages": history + [AIMessage(content=answer, name=name)]}

    completion_flag = {
        "QuizAgent": "quiz_created",
        "QAAgent": "question_answered",
        "FlashcardsAgent": "flashcards_created",
    }.get(name)
    if completion_flag is not None:
        update[completion_flag] = True
    if name == "FlashcardsAgent":
        # Heuristic: the filename is whatever follows the last "(" in the
        # reply, minus closing parens, e.g. "... (cards.csv)" -> "cards.csv".
        update["flashcard_filename"] = answer.split("(")[-1].strip(")")

    update["next"] = "FINISH"
    return update
# Function to create the supervisor
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> Runnable:
    """Build an LLM router that decides which team member acts next.

    The result is a ``prompt | llm | JsonOutputFunctionsParser`` chain, not
    an ``AgentExecutor``. The model is forced to call the ``route`` function,
    so invoking the chain with a state containing ``messages`` yields the
    parsed arguments, e.g. ``{"next": "QAAgent"}``; ``next`` is one of
    ``"WAIT"``, ``"FINISH"`` or a member name.

    Args:
        llm: Chat model that supports OpenAI function calling.
        system_prompt: Supervisor instructions. May reference
            ``{team_members}``, which is pre-filled with the member names.
        members: Names of the agents the supervisor can route to; any
            iterable of strings is accepted.

    Returns:
        The runnable routing chain.
    """
    # Materialise once: members feeds both the route enum and team_members,
    # and list concatenation below would reject a tuple.
    members = list(members)
    options = ["WAIT", "FINISH"] + members
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we WAIT for user input? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    return (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
# Create the LangGraph chain
def create_tutor_chain(retrieval_chain):
    """Compile the tutor graph: a supervisor routing to QA, quiz and flashcard agents.

    Args:
        retrieval_chain: Chain used by the retrieval tool to look up notebook
            content (see ``RetrievalChainWrapper``).

    Returns:
        The compiled LangGraph graph over ``TutorState``. Execution starts at
        the supervisor; each agent hands control back to the supervisor, and
        the run ends once any agent has completed its task or the supervisor
        answers FINISH or WAIT.
    """
    retrieve_information_tool = get_retrieve_information_tool(retrieval_chain)

    # QA agent: answers questions about the notebook.
    qa_agent = create_agent(
        llm,
        [retrieve_information_tool],
        "You are a QA assistant who answers questions about the provided notebook content.",
    )
    qa_node = functools.partial(agent_node, agent=qa_agent, name="QAAgent")

    # Quiz agent. The prompt must name the retrieval tool the way the model
    # sees it; the tool wraps the ``retrieve_information`` method and is
    # presumably registered under that name.
    quiz_agent = create_agent(
        llm,
        [retrieve_information_tool],
        "You are a quiz creator that generates quizzes based on the provided notebook content.\n"
        "First, you MUST use the retrieve_information tool to gather relevant and accurate information from the notebook.\n"
        "Next, create a 5-question quiz based on the information you have gathered. Include the answers at the end of the quiz.\n"
        "Present the quiz to the user in a clear and concise manner.",
    )
    quiz_node = functools.partial(agent_node, agent=quiz_agent, name="QuizAgent")

    # Flashcards agent.
    # NOTE(review): the prompt refers to "flashcard_tool"; the name the model
    # actually sees comes from tools.create_flashcards_tool — confirm they match.
    flashcards_agent = create_agent(
        llm,
        [retrieve_information_tool, flashcard_tool],
        "\n"
        "You are the Flashcard creator. Your mission is to create effective and concise flashcards based on the user's query and the content of the provided notebook. Your role involves the following tasks:\n"
        "1. Analyze User Query: Understand the user's request and determine the key concepts and information they need to learn.\n"
        "2. Search Notebook Content: Use the notebook content to gather relevant information and generate accurate and informative flashcards.\n"
        "3. Generate Flashcards: Create a series of flashcards content with clear questions on the front and detailed answers on the back. Ensure that the flashcards cover the essential points and concepts requested by the user.\n"
        "4. Export Flashcards: Use the flashcard_tool to create and export the flashcards in a format that can be easily imported into a flashcard management system, such as Anki.\n"
        "Remember, your goal is to help the user learn efficiently and effectively by breaking down the notebook content into manageable, repeatable flashcards.",
    )
    flashcards_node = functools.partial(agent_node, agent=flashcards_agent, name="FlashcardsAgent")

    # Supervisor: an LLM router choosing among the three agents.
    supervisor_agent = create_team_supervisor(
        llm,
        "You are a supervisor tasked with managing a conversation between the following agents: QAAgent, QuizAgent, FlashcardsAgent. Given the user request, decide which agent should act next.",
        ["QAAgent", "QuizAgent", "FlashcardsAgent"],
    )

    tutor_graph = StateGraph(TutorState)
    tutor_graph.add_node("QAAgent", qa_node)
    tutor_graph.add_node("QuizAgent", quiz_node)
    tutor_graph.add_node("FlashcardsAgent", flashcards_node)
    tutor_graph.add_node("supervisor", supervisor_agent)
    tutor_graph.add_edge("QAAgent", "supervisor")
    tutor_graph.add_edge("QuizAgent", "supervisor")
    tutor_graph.add_edge("FlashcardsAgent", "supervisor")

    def route_from_supervisor(state):
        """Pick the supervisor's successor from the current state."""
        # agent_node sets a completion flag for every agent it runs, so once
        # any agent has finished, the run ends whatever the supervisor chose
        # on its follow-up turn.
        if (
            state.get("quiz_created")
            or state.get("question_answered")
            or state.get("flashcards_created")
        ):
            return "FINISH"
        return state["next"]

    tutor_graph.add_conditional_edges(
        "supervisor",
        route_from_supervisor,
        {
            "QAAgent": "QAAgent",
            "QuizAgent": "QuizAgent",
            "FlashcardsAgent": "FlashcardsAgent",
            "FINISH": END,
            # create_team_supervisor also offers WAIT ("wait for user input").
            # Nothing inside the graph can supply that input, so end the run
            # and return control to the caller instead of failing to route.
            "WAIT": END,
        },
    )
    tutor_graph.set_entry_point("supervisor")
    return tutor_graph.compile()
|