# ERP-system / erp_core/node_builder.py
# sarwarshafee8709809365 — deployment-1 — c8e458d
from typing import Literal
from erp_core.state_definer import State
from langchain_core.messages import ToolMessage
from erp_core._event import create_tool_node_with_fallback
from erp_core.assistant_class import Assistant, CompleteOrEscalate
from erp_core.entry_node import create_entry_node
from langgraph.graph import StateGraph
from langgraph.prebuilt import tools_condition
from langgraph.graph import END, StateGraph, START
from operator import __and__
from langgraph.checkpoint.memory import MemorySaver
# from langgraph.checkpoint.sqlite import SqliteSaver
from erp_core.runnable.fm_prompt import financial_management_runnable, financial_management_tools
from erp_core.runnable.scm_prompt import supply_chain_management_runnable, supply_chain_management_tools
from erp_core.runnable.hr_prompt import human_resource_runnable, human_resource_tools
from erp_core.runnable.pm_prompt import project_management_runnable, project_management_tools
from erp_core.runnable.crm_prompt import customer_relationship_management_runnable, customer_relationship_management_tools
from erp_core.runnable.primary_assistant_prompt import assistant_runnable, primary_assistant_tools
from erp_core.tool_binder.tool_binder import ToHumanResourceDepartment, ToFinancialManagementDepartment, ToSupplyChainManagementDepartment, ToProjectManagementDepartment, ToCustomerRelationshipManagementDepartment
builder = StateGraph(State)


# fetch user info
# ........................................................................
def user_info(state: State):
    """Populate the ``user_info`` state key before any assistant runs.

    This is currently a placeholder: the incoming state is ignored and the
    key is always set to an empty string.
    """
    empty_profile = ""
    return {"user_info": empty_profile}


# Every run enters the graph through this node; from here the run is routed
# to the active assistant by the conditional edge on "fetch_user_info".
builder.add_node("fetch_user_info", user_info)
builder.add_edge(START, "fetch_user_info")
# financial management assistant
# ........................................................................
builder.add_node(
    "enter_financial_management",
    create_entry_node("Financial Management Assistant", "financial_management"),
)
builder.add_node("financial_management", Assistant(financial_management_runnable))
builder.add_edge("enter_financial_management", "financial_management")
builder.add_node(
    "financial_management_tools",
    create_tool_node_with_fallback(financial_management_tools),
)


def route_financial_management(
    state: State,
) -> Literal[
    "financial_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the financial management assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message (the assistant answered the user directly);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate``
        (control goes back to the primary assistant);
        ``"financial_management_tools"`` for every other tool call.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # There is a single tool node for this department (no separate node for
    # "sensitive" tools), so every remaining call is routed to it.
    return "financial_management_tools"


builder.add_edge("financial_management_tools", "financial_management")
builder.add_conditional_edges("financial_management", route_financial_management)
# supply chain management assistant
# ........................................................................
builder.add_node(
    "enter_supply_chain_management",
    create_entry_node("Supply Chain Management Assistant", "supply_chain_management"),
)
builder.add_node("supply_chain_management", Assistant(supply_chain_management_runnable))
builder.add_edge("enter_supply_chain_management", "supply_chain_management")
builder.add_node(
    "supply_chain_management_tools",
    create_tool_node_with_fallback(supply_chain_management_tools),
)


def route_supply_chain_management(
    state: State,
) -> Literal[
    "supply_chain_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the supply chain management assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message (the assistant answered the user directly);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate``
        (control goes back to the primary assistant);
        ``"supply_chain_management_tools"`` for every other tool call.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # There is a single tool node for this department (no separate node for
    # "sensitive" tools), so every remaining call is routed to it.
    return "supply_chain_management_tools"


builder.add_edge("supply_chain_management_tools", "supply_chain_management")
builder.add_conditional_edges("supply_chain_management", route_supply_chain_management)
# human resource assistant
# ........................................................................
builder.add_node(
    "enter_human_resource",
    create_entry_node("Human Resource Assistant", "human_resource"),
)
builder.add_node("human_resource", Assistant(human_resource_runnable))
builder.add_edge("enter_human_resource", "human_resource")
builder.add_node(
    "human_resource_tools",
    create_tool_node_with_fallback(human_resource_tools),
)


def route_human_resource(
    state: State,
) -> Literal[
    "human_resource_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the human resource assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message (the assistant answered the user directly);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate``
        (control goes back to the primary assistant);
        ``"human_resource_tools"`` for every other tool call.
    """
    if tools_condition(state) == END:
        return END  # end the graph
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # There is a single tool node for this department (no separate node for
    # "sensitive" tools), so every remaining call is routed to it.
    return "human_resource_tools"


builder.add_edge("human_resource_tools", "human_resource")
builder.add_conditional_edges("human_resource", route_human_resource)
# Project management assistant
# ........................................................................
builder.add_node(
    "enter_project_management",
    create_entry_node("Project Management Assistant", "project_management"),
)
builder.add_node("project_management", Assistant(project_management_runnable))
builder.add_edge("enter_project_management", "project_management")
builder.add_node(
    "project_management_tools",
    create_tool_node_with_fallback(project_management_tools),
)


def route_project_management(
    state: State,
) -> Literal[
    "project_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the project management assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message (the assistant answered the user directly);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate``
        (control goes back to the primary assistant);
        ``"project_management_tools"`` for every other tool call.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # There is a single tool node for this department (no separate node for
    # "sensitive" tools), so every remaining call is routed to it.
    return "project_management_tools"


builder.add_edge("project_management_tools", "project_management")
builder.add_conditional_edges("project_management", route_project_management)
# customer relationship management assistant
# ........................................................................
builder.add_node(
    "enter_customer_relationship_management",
    create_entry_node(
        "Customer Relationship Management Assistant",
        "customer_relationship_management",
    ),
)
builder.add_node(
    "customer_relationship_management",
    Assistant(customer_relationship_management_runnable),
)
builder.add_edge(
    "enter_customer_relationship_management", "customer_relationship_management"
)
builder.add_node(
    "customer_relationship_management_tools",
    create_tool_node_with_fallback(customer_relationship_management_tools),
)


def route_customer_relationship_management(
    state: State,
) -> Literal[
    "customer_relationship_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the CRM assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message (the assistant answered the user directly);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate``
        (control goes back to the primary assistant);
        ``"customer_relationship_management_tools"`` for every other tool call.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # There is a single tool node for this department (no separate node for
    # "sensitive" tools), so every remaining call is routed to it.
    return "customer_relationship_management_tools"


builder.add_edge(
    "customer_relationship_management_tools", "customer_relationship_management"
)
builder.add_conditional_edges(
    "customer_relationship_management", route_customer_relationship_management
)
# leave skill
# ........................................................................
def pop_dialog_state(state: State) -> dict:
    """Pop the dialog stack and return to the main assistant.

    This lets the full graph explicitly track the dialog flow and delegate
    control to specific sub-graphs.

    Returns:
        A state update whose ``dialog_state`` is ``"pop"``. Its ``messages``
        hold one ``ToolMessage`` for each tool call on the last message, or
        are empty when that message made no tool calls.
    """
    last_message = state["messages"][-1]
    # Answer every tool call, not just the first one. When the LLM issues
    # parallel tool calls, leaving some tool_call_ids without a matching tool
    # message yields a history that chat-completion APIs (e.g. OpenAI's)
    # reject on the next model call.
    messages = [
        ToolMessage(
            content="Resuming dialog with the host assistant. Please reflect on the past conversation and assist the user as needed.",
            tool_call_id=tool_call["id"],
        )
        for tool_call in (last_message.tool_calls or [])
    ]
    return {
        "dialog_state": "pop",
        "messages": messages,
    }


builder.add_node("leave_skill", pop_dialog_state)
builder.add_edge("leave_skill", "primary_assistant")
# primary assistant
# ........................................................................
builder.add_node("primary_assistant", Assistant(assistant_runnable))
builder.add_node(
    "primary_assistant_tools",
    create_tool_node_with_fallback(primary_assistant_tools),
)

# Maps the name of each department hand-off tool to the entry node of that
# department's assistant.
_DEPARTMENT_ENTRY_NODES = {
    ToHumanResourceDepartment.__name__: "enter_human_resource",
    ToFinancialManagementDepartment.__name__: "enter_financial_management",
    ToSupplyChainManagementDepartment.__name__: "enter_supply_chain_management",
    ToProjectManagementDepartment.__name__: "enter_project_management",
    ToCustomerRelationshipManagementDepartment.__name__: "enter_customer_relationship_management",
}


def route_primary_assistant(
    state: State,
) -> Literal[
    "primary_assistant_tools",
    "enter_human_resource",
    "enter_financial_management",
    "enter_supply_chain_management",
    "enter_project_management",
    "enter_customer_relationship_management",
    "__end__",
]:
    """Choose the next node after the primary assistant speaks.

    Returns:
        ``END`` when ``tools_condition`` reports no tool calls on the last
        message; the department entry node when the first tool call is a
        department hand-off tool; ``"primary_assistant_tools"`` for any other
        tool call.

    Raises:
        ValueError: if ``tools_condition`` did not return ``END`` but the
            last message carries no tool calls.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if not tool_calls:
        raise ValueError("Invalid route")
    # Only the first tool call decides the route; additional parallel calls
    # are not considered here.
    return _DEPARTMENT_ENTRY_NODES.get(tool_calls[0]["name"], "primary_assistant_tools")


# The assistant can route to one of the delegated assistants,
# directly use a tool, or directly respond to the user
builder.add_conditional_edges(
    "primary_assistant",
    route_primary_assistant,
    {
        "enter_human_resource": "enter_human_resource",
        "enter_financial_management": "enter_financial_management",
        "enter_supply_chain_management": "enter_supply_chain_management",
        "enter_project_management": "enter_project_management",
        "enter_customer_relationship_management": "enter_customer_relationship_management",
        "primary_assistant_tools": "primary_assistant_tools",
        END: END,
    },
)
builder.add_edge("primary_assistant_tools", "primary_assistant")
# Each delegated workflow can directly respond to the user
# When the user responds, we want to return to the currently active workflow
def route_to_workflow(
    state: State,
) -> Literal[
    "primary_assistant",
    "human_resource",
    "financial_management",
    "supply_chain_management",
    "project_management",
    "customer_relationship_management",
]:
    """Pick the assistant that should handle the incoming user turn.

    The last entry of ``state["dialog_state"]`` names the currently active
    assistant; when that key is missing or empty, the primary assistant
    handles the turn.
    """
    stack = state.get("dialog_state")
    return stack[-1] if stack else "primary_assistant"


builder.add_conditional_edges("fetch_user_info", route_to_workflow)
# Compile graph
def compile_graph(checkpointer=None):
    """Compile the module-level ``builder`` into a runnable graph.

    Args:
        checkpointer: Checkpoint saver used to persist conversation state
            between invocations. When omitted, a new in-memory
            ``MemorySaver`` is created, so state is not shared between
            separate ``compile_graph()`` calls and is not persisted to disk.

    Returns:
        The compiled graph returned by ``builder.compile``.
    """
    if checkpointer is None:
        checkpointer = MemorySaver()
    return builder.compile(checkpointer=checkpointer)