File size: 3,401 Bytes
76c5345
a0503bc
76c5345
224ff63
 
 
 
 
 
 
8b9c87b
224ff63
 
 
8b9c87b
796ceef
 
 
 
 
 
224ff63
 
 
 
 
796ceef
 
224ff63
796ceef
 
 
 
 
 
 
224ff63
 
 
796ceef
 
224ff63
 
 
 
 
 
 
 
 
 
 
 
 
796ceef
 
 
224ff63
 
 
796ceef
224ff63
dbf2f6d
224ff63
 
 
 
796ceef
dbf2f6d
224ff63
 
 
 
 
 
 
 
 
 
796ceef
dbf2f6d
224ff63
 
 
8b9c87b
33b89d4
224ff63
8b9c87b
224ff63
76c5345
224ff63
 
 
 
dbf2f6d
224ff63
 
 
 
 
a93b719
 
 
 
 
796ceef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106

def agent(payload):

    DEBUG=True

    from agent.memory import Memory
    memory = Memory(payload)

    from agent.jsonencoder import json_parse_chain
    from agent.agent_main import Chain_Main_Agent

    chain_main_agent = Chain_Main_Agent(memory)

    
    from agent.toolset import tool_executor, converted_tools

    from langgraph.prebuilt import ToolInvocation
    import json
    from langchain_core.messages import FunctionMessage


    def call_main_agent(messages):   
        response = chain_main_agent.invoke({"conversation":messages, "thread_id": memory.thread_id})
        
        if DEBUG: print("call_main_agent called");                                        

        return response                                                 

    def use_tool(messages):    
        last_message = messages[-1]                             
        action = ToolInvocation(
            tool=last_message.additional_kwargs["function_call"]["name"],
            tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
        )
        response = tool_executor.invoke(action)
        function_message = FunctionMessage(content=str(response), name=action.tool)

        if DEBUG: print("Suggesting Tool to use..."+action.tool);                                                                                

        return function_message

    def render_output(messages):

        import json

        response = json_parse_chain.invoke({"conversation":messages, "thread_id": memory.thread_id})

        if DEBUG: print("Rendering output");   

        from langchain_core.messages import AIMessage                                     

        response = json.dumps(response)
        return AIMessage(content=response)

    from langgraph.graph import MessageGraph, END
    workflow = MessageGraph()

    workflow.add_node("main_agent", call_main_agent)
    workflow.add_node("use_tool", use_tool)
    workflow.add_node("render_output", render_output)

    workflow.set_entry_point("main_agent")

    def should_continue(messages):                                        
        last_message = messages[-1]
        if "function_call" not in last_message.additional_kwargs: return "render_output"
        else: return "continue"


    workflow.add_conditional_edges(
        "main_agent", should_continue, 
        {
            "continue": "use_tool", 
            "render_output":"render_output",
            "end": END
        }
    )
    workflow.add_edge('use_tool', 'main_agent')
    workflow.add_edge('render_output', END)


    app = workflow.compile(checkpointer=memory.checkpoints)
    
    from langchain_core.messages import HumanMessage

    input = payload.get("input") or "Can I earn credit?"
    inputs = [HumanMessage(content=input)]

    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )

    '''
    inputs = [HumanMessage(content="My name is Mark")]
    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )
    print(response[-1].content)

    inputs = [HumanMessage(content="What is my name?")]
    response = app.invoke(inputs, {"configurable": {"thread_id": memory.thread_id} } )
    print(response[-1].content)
    '''
    
    response = response[-1].content[:-1] + ', "thread_id": "' + str(memory.thread_id) + '"}'
    

    print(response);
    return response