File size: 3,830 Bytes
963e565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# ChatOpenAI reads OPENAI_API_KEY from the environment by itself, so there is
# nothing to copy. Re-assigning os.environ from os.getenv() was a no-op when the
# variable was set, and raised an opaque TypeError (os.environ only accepts str,
# not None) when it was missing. Fail with an explicit message instead.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY environment variable is not set.")

# Model Initialization
# temperature=0 keeps the model's answers and tool-call decisions as repeatable as possible.
model = ChatOpenAI(temperature=0)

# Define the tool
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    # NOTE: @tool uses the docstring above as the tool description and the
    # parameter names/annotations as the argument schema the model sees, so
    # edit them deliberately.
    product = first_number * second_number
    return product

# Bind tool to model
# Convert the LangChain tool to the OpenAI function-calling schema once, then
# attach it so every call through model_with_tools advertises `multiply`.
multiply_tool_schema = convert_to_openai_tool(multiply)
model_with_tools = model.bind(tools=[multiply_tool_schema])

# State Setup
class AgentState(TypedDict):
    """State passed between the graph's nodes.

    Nodes return only the messages they newly produce. The ``operator.add``
    reducer attached to ``messages`` concatenates that partial update onto the
    existing history.
    """

    # Conversation history; the Annotated metadata is the reducer for updates.
    messages: Annotated[Sequence, operator.add]

graph = StateGraph(AgentState)

# Model Invocation
def invoke_model(state):
    """Ask the tool-bound model about the most recent message.

    Only the text content of the newest message is sent; earlier messages in
    ``state['messages']`` are not forwarded. The model's reply is returned as a
    one-element ``messages`` update, which the state reducer appends.
    """
    newest = state['messages'][-1]
    reply = model_with_tools.invoke(newest.content)
    return {"messages": [reply]}

graph.add_node("agent", invoke_model)

# Tool Invocation
def invoke_tool(state):
    """Execute the tool calls requested by the latest message.

    Reads OpenAI-format calls from ``additional_kwargs["tool_calls"]`` of the
    last message. For each call that names ``multiply``, the JSON arguments are
    decoded and the tool is run. Calls naming any other tool are skipped.

    Returns a ``messages`` update holding one AIMessage per executed call, each
    tagged with the originating call id. If nothing was executed, it holds a
    single fallback AIMessage instead.
    """
    known_tools = {"multiply": multiply}
    requested = state['messages'][-1].additional_kwargs.get("tool_calls", [])

    outputs = []
    for call in requested:
        spec = call["function"]
        selected = known_tools.get(spec["name"])
        if selected is None:
            continue
        value = selected.invoke(json.loads(spec["arguments"]))
        outputs.append(
            AIMessage(
                content=f"Tool Result: {value}",
                additional_kwargs={"tool_call_id": call["id"]},
            )
        )

    if not outputs:
        outputs = [AIMessage(content="No valid tool input provided.")]
    return {"messages": outputs}

graph.add_node("tool", invoke_tool)

# Explicit Router Node
def router(state):
    """
    Pass-through checkpoint between the agent and the tool/END decision.

    The decision itself is made by the conditional edge leaving this node,
    which looks for ``tool_calls`` on ``state['messages'][-1]``. The tool node
    reads that same last message. Appending a message here would therefore
    hide the agent's tool calls from both, and the tool would never run.

    Returns an empty ``messages`` update, which the ``operator.add`` reducer
    turns into a no-op. It returns an explicit update rather than nothing.
    NOTE(review): some LangGraph versions reject nodes that write no channel;
    confirm whether this is still needed.
    """
    return {"messages": []}

graph.add_node("router", router)

# Add explicit edges
def _next_after_router(state):
    """Return "tool" when the newest message carries tool calls, otherwise END."""
    newest = state['messages'][-1]
    wants_tool = bool(newest.additional_kwargs.get("tool_calls"))
    return "tool" if wants_tool else END

graph.add_edge("agent", "router")
graph.add_conditional_edges(
    "router",
    _next_after_router,
    {"tool": "tool", END: END},
)
graph.add_edge("tool", END)

# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()

# Save graph visualization with xray for visibility
# st.image accepts raw PNG bytes, so the diagram is rendered straight to
# memory. Writing a NamedTemporaryFile(delete=False) left a new PNG on disk
# every time Streamlit re-ran this script.
def _render_workflow_png():
    """Return the compiled graph's Mermaid diagram as PNG bytes, or None if rendering fails."""
    try:
        return app_graph.get_graph(xray=True).draw_mermaid_png()
    except Exception as err:  # the diagram is decorative; keep the demo usable
        st.warning(f"Could not render workflow diagram: {err}")
        return None

def _answer(question):
    """Run the graph on one human question and display the last message it produced."""
    try:
        output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
        st.success(output['messages'][-1].content)
    except Exception as e:
        st.error(f"Error: {e}")

# Streamlit Interface
st.title("Simple Tool Calling Demo")
graph_png = _render_workflow_png()
if graph_png is not None:
    st.image(graph_png, caption="Workflow Visualization")

tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

with tab1:
    st.subheader("Try Multiplication")
    first_number = st.number_input("First Number", value=0, step=1)
    second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        _answer(f"What is {first_number} * {second_number}?")

with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        # Don't spend a model call on an empty/whitespace-only question.
        if user_input.strip():
            _answer(user_input)
        else:
            st.warning("Please enter a question first.")