Spaces:
Sleeping
Sleeping
from typing import Literal | |
from erp_core.state_definer import State | |
from langchain_core.messages import ToolMessage | |
from erp_core._event import create_tool_node_with_fallback | |
from erp_core.assistant_class import Assistant, CompleteOrEscalate | |
from erp_core.entry_node import create_entry_node | |
from langgraph.graph import StateGraph | |
from langgraph.prebuilt import tools_condition | |
from langgraph.graph import END, StateGraph, START | |
from operator import __and__ | |
from langgraph.checkpoint.memory import MemorySaver | |
# from langgraph.checkpoint.sqlite import SqliteSaver | |
from erp_core.runnable.fm_prompt import financial_management_runnable, financial_management_tools | |
from erp_core.runnable.scm_prompt import supply_chain_management_runnable, supply_chain_management_tools | |
from erp_core.runnable.hr_prompt import human_resource_runnable, human_resource_tools | |
from erp_core.runnable.pm_prompt import project_management_runnable, project_management_tools | |
from erp_core.runnable.crm_prompt import customer_relationship_management_runnable, customer_relationship_management_tools | |
from erp_core.runnable.primary_assistant_prompt import assistant_runnable, primary_assistant_tools | |
from erp_core.tool_binder.tool_binder import ToHumanResourceDepartment, ToFinancialManagementDepartment, ToSupplyChainManagementDepartment, ToProjectManagementDepartment, ToCustomerRelationshipManagementDepartment | |
builder = StateGraph(State)

# fetch user info
# ........................................................................


def user_info(state: State):
    """Populate the ``user_info`` state channel.

    Placeholder implementation: ``state`` is not inspected and the channel is
    always set to an empty string.
    """
    empty_profile = ""
    return {"user_info": empty_profile}


builder.add_node("fetch_user_info", user_info)
# Every graph run starts at the user-info node.
builder.add_edge(START, "fetch_user_info")
# financial management assistant
# ........................................................................
builder.add_node(
    "enter_financial_management",
    create_entry_node("Financial Management Assistant", "financial_management"),
)
builder.add_node("financial_management", Assistant(financial_management_runnable))
builder.add_edge("enter_financial_management", "financial_management")
builder.add_node(
    "financial_management_tools",
    create_tool_node_with_fallback(financial_management_tools),
)


def route_financial_management(
    state: State,
) -> Literal[
    "financial_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the financial management assistant runs.

    Returns:
        ``END`` when ``tools_condition`` finds no tool calls (reply to the user);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate`` (hand
        control back to the primary assistant); otherwise
        ``"financial_management_tools"``.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # This assistant has a single tool node, so every remaining call goes there.
    return "financial_management_tools"


builder.add_edge("financial_management_tools", "financial_management")
builder.add_conditional_edges("financial_management", route_financial_management)
# supply chain management assistant
# ........................................................................
builder.add_node(
    "enter_supply_chain_management",
    create_entry_node("Supply Chain Management Assistant", "supply_chain_management"),
)
builder.add_node("supply_chain_management", Assistant(supply_chain_management_runnable))
builder.add_edge("enter_supply_chain_management", "supply_chain_management")
builder.add_node(
    "supply_chain_management_tools",
    create_tool_node_with_fallback(supply_chain_management_tools),
)


def route_supply_chain_management(
    state: State,
) -> Literal[
    "supply_chain_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the supply chain management assistant runs.

    Returns:
        ``END`` when ``tools_condition`` finds no tool calls (reply to the user);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate`` (hand
        control back to the primary assistant); otherwise
        ``"supply_chain_management_tools"``.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # This assistant has a single tool node, so every remaining call goes there.
    return "supply_chain_management_tools"


builder.add_edge("supply_chain_management_tools", "supply_chain_management")
builder.add_conditional_edges("supply_chain_management", route_supply_chain_management)
# human resource assistant
# ........................................................................
builder.add_node(
    "enter_human_resource",
    create_entry_node("Human Resource Assistant", "human_resource"),
)
builder.add_node("human_resource", Assistant(human_resource_runnable))
builder.add_edge("enter_human_resource", "human_resource")
builder.add_node(
    "human_resource_tools",
    create_tool_node_with_fallback(human_resource_tools),
)


def route_human_resource(
    state: State,
) -> Literal[
    "human_resource_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the human resource assistant runs.

    Returns:
        ``END`` when ``tools_condition`` finds no tool calls (end the graph and
        reply to the user); ``"leave_skill"`` when any tool call is
        ``CompleteOrEscalate`` (hand control back to the primary assistant);
        otherwise ``"human_resource_tools"``.
    """
    if tools_condition(state) == END:
        return END  # end the graph
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # This assistant has a single tool node, so every remaining call goes there.
    return "human_resource_tools"


builder.add_edge("human_resource_tools", "human_resource")
builder.add_conditional_edges("human_resource", route_human_resource)
# Project management assistant
# ........................................................................
builder.add_node(
    "enter_project_management",
    create_entry_node("Project Management Assistant", "project_management"),
)
builder.add_node("project_management", Assistant(project_management_runnable))
builder.add_edge("enter_project_management", "project_management")
builder.add_node(
    "project_management_tools",
    create_tool_node_with_fallback(project_management_tools),
)


def route_project_management(
    state: State,
) -> Literal[
    "project_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the project management assistant runs.

    Returns:
        ``END`` when ``tools_condition`` finds no tool calls (reply to the user);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate`` (hand
        control back to the primary assistant); otherwise
        ``"project_management_tools"``.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # This assistant has a single tool node, so every remaining call goes there.
    return "project_management_tools"


builder.add_edge("project_management_tools", "project_management")
builder.add_conditional_edges("project_management", route_project_management)
# customer relationship management assistant
# ........................................................................
builder.add_node(
    "enter_customer_relationship_management",
    create_entry_node(
        "Customer Relationship Management Assistant",
        "customer_relationship_management",
    ),
)
builder.add_node(
    "customer_relationship_management",
    Assistant(customer_relationship_management_runnable),
)
builder.add_edge(
    "enter_customer_relationship_management", "customer_relationship_management"
)
builder.add_node(
    "customer_relationship_management_tools",
    create_tool_node_with_fallback(customer_relationship_management_tools),
)


def route_customer_relationship_management(
    state: State,
) -> Literal[
    "customer_relationship_management_tools",
    "leave_skill",
    "__end__",
]:
    """Choose the next node after the CRM assistant runs.

    Returns:
        ``END`` when ``tools_condition`` finds no tool calls (reply to the user);
        ``"leave_skill"`` when any tool call is ``CompleteOrEscalate`` (hand
        control back to the primary assistant); otherwise
        ``"customer_relationship_management_tools"``.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if any(tc["name"] == CompleteOrEscalate.__name__ for tc in tool_calls):
        return "leave_skill"
    # This assistant has a single tool node, so every remaining call goes there.
    return "customer_relationship_management_tools"


builder.add_edge(
    "customer_relationship_management_tools", "customer_relationship_management"
)
builder.add_conditional_edges(
    "customer_relationship_management", route_customer_relationship_management
)
# leave skill
# ........................................................................


def pop_dialog_state(state: State) -> dict:
    """Pop the dialog stack and return to the main assistant.

    This lets the full graph explicitly track the dialog flow and delegate
    control to specific sub-graphs.

    Every tool call on the last message is answered with a ``ToolMessage``,
    not only the first one. Chat-model APIs such as OpenAI's typically reject
    a history in which an assistant tool call has no matching tool response.
    That situation occurs when the LLM emits parallel tool calls alongside
    ``CompleteOrEscalate``. The escalation call receives the resume
    instruction; any other call is reported as not executed.

    Returns:
        A state update with ``dialog_state`` set to ``"pop"`` and ``messages``
        holding the tool responses (empty if the last message made no tool
        calls).
    """
    tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []
    messages = []
    for tool_call in tool_calls:
        if tool_call["name"] == CompleteOrEscalate.__name__:
            content = (
                "Resuming dialog with the host assistant. Please reflect on the "
                "past conversation and assist the user as needed."
            )
        else:
            content = (
                "Tool call not executed: control was returned to the host "
                "assistant."
            )
        messages.append(ToolMessage(content=content, tool_call_id=tool_call["id"]))
    return {
        "dialog_state": "pop",
        "messages": messages,
    }


builder.add_node("leave_skill", pop_dialog_state)
builder.add_edge("leave_skill", "primary_assistant")
# primary assistant
# ........................................................................
builder.add_node("primary_assistant", Assistant(assistant_runnable))
builder.add_node(
    "primary_assistant_tools",
    create_tool_node_with_fallback(primary_assistant_tools),
)

# Delegation tool name -> entry node of the department assistant it hands off to.
_DELEGATION_ENTRY_NODES = {
    ToHumanResourceDepartment.__name__: "enter_human_resource",
    ToFinancialManagementDepartment.__name__: "enter_financial_management",
    ToSupplyChainManagementDepartment.__name__: "enter_supply_chain_management",
    ToProjectManagementDepartment.__name__: "enter_project_management",
    ToCustomerRelationshipManagementDepartment.__name__: "enter_customer_relationship_management",
}


def route_primary_assistant(
    state: State,
) -> Literal[
    "primary_assistant_tools",
    "enter_human_resource",
    "enter_financial_management",
    "enter_supply_chain_management",
    "enter_project_management",
    "enter_customer_relationship_management",
    "__end__",
]:
    """Choose the next node after the primary assistant runs.

    Returns:
        ``END`` when there are no tool calls; the matching ``enter_*`` node when
        the first tool call is one of the department delegation tools;
        otherwise ``"primary_assistant_tools"``.

    Raises:
        ValueError: if ``tools_condition`` did not return ``END`` but the last
            message has an empty ``tool_calls`` list.
    """
    if tools_condition(state) == END:
        return END
    tool_calls = state["messages"][-1].tool_calls
    if not tool_calls:
        raise ValueError("Invalid route")
    # Only the first tool call decides the hand-off, as in the original chain.
    return _DELEGATION_ENTRY_NODES.get(tool_calls[0]["name"], "primary_assistant_tools")


# The assistant can route to one of the delegated assistants,
# directly use a tool, or directly respond to the user
builder.add_conditional_edges(
    "primary_assistant",
    route_primary_assistant,
    {
        "enter_human_resource": "enter_human_resource",
        "enter_financial_management": "enter_financial_management",
        "enter_supply_chain_management": "enter_supply_chain_management",
        "enter_project_management": "enter_project_management",
        "enter_customer_relationship_management": "enter_customer_relationship_management",
        "primary_assistant_tools": "primary_assistant_tools",
        END: END,
    },
)
builder.add_edge("primary_assistant_tools", "primary_assistant")
# Each delegated workflow can respond to the user directly.
# When the user replies, control returns to whichever workflow is active.


def route_to_workflow(
    state: State,
) -> Literal[
    "primary_assistant",
    "human_resource",
    "financial_management",
    "supply_chain_management",
    "project_management",
    "customer_relationship_management",
]:
    """Select the assistant that handles the next user turn.

    Reads ``dialog_state``. When it is missing or empty, the primary assistant
    takes the turn; otherwise the last entry names the active assistant.
    """
    active_stack = state.get("dialog_state")
    return active_stack[-1] if active_stack else "primary_assistant"


builder.add_conditional_edges("fetch_user_info", route_to_workflow)
# Compile graph


def compile_graph(checkpointer=None):
    """Compile the module-level ``builder`` into a runnable graph.

    Args:
        checkpointer: Checkpoint saver used to persist conversation state
            between invocations. Defaults to a new ``MemorySaver`` for each
            call, so graphs from separate calls do not share state. Pass
            another saver (e.g. a SQLite-backed one) to persist elsewhere.

    Returns:
        The compiled graph returned by ``builder.compile``.
    """
    if checkpointer is None:
        checkpointer = MemorySaver()
    return builder.compile(checkpointer=checkpointer)