File size: 4,026 Bytes
54343c0
 
 
e19b96d
012fbad
54343c0
 
 
569f7b7
54343c0
1d4c700
54343c0
e19b96d
 
54343c0
e19b96d
54343c0
 
beab746
54343c0
 
e19b96d
54343c0
 
beab746
54343c0
 
e19b96d
1d4c700
e19b96d
293631b
1d4c700
54343c0
beab746
1d4c700
d3ffdbc
20d5bf2
beab746
54343c0
1d4c700
 
beab746
1d4c700
1a67ad9
d3ffdbc
 
1d4c700
d3ffdbc
 
beab746
d3ffdbc
 
 
 
bced22e
1d4c700
df88f89
d3ffdbc
1d4c700
569f7b7
d3ffdbc
569f7b7
beab746
d3ffdbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e19b96d
d3ffdbc
 
e19b96d
 
d3ffdbc
e19b96d
569f7b7
beab746
e19b96d
 
 
 
 
 
 
 
 
 
d3ffdbc
 
e19b96d
 
 
beab746
 
 
 
 
e19b96d
 
 
 
 
 
d3ffdbc
 
 
 
 
e19b96d
beab746
4a26df1
569f7b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# os.environ only accepts str values, so the previous unconditional
# `os.environ[...] = os.getenv(...)` raised TypeError at import time whenever
# OPENAI_API_KEY was not set. Only write the value back when it is present;
# a missing key is then reported by whichever component actually needs it.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key

# Model Initialization
# Single chat model shared by the whole script. The sampling temperature is
# pinned to 0 and no model name is passed, so ChatOpenAI's own default applies.
model = ChatOpenAI(temperature=0)

# Define the tool
# NOTE(review): @tool presumably surfaces the docstring below to the model as
# the tool description, so its wording is kept verbatim — confirm before editing.
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    product = operator.mul(first_number, second_number)
    return product

# Bind tool to model
# First convert the tool into its OpenAI tool schema, then bind that schema
# to the model as the `tools` argument.
multiply_tool_schema = convert_to_openai_tool(multiply)
model_with_tools = model.bind(tools=[multiply_tool_schema])

# State Setup
# The Annotated metadata attaches operator.add to the channel type.
# NOTE(review): LangGraph is expected to use it as the reducer that appends a
# node's returned messages to the existing sequence — confirm against the docs.
_MessageLog = Annotated[Sequence, operator.add]


class AgentState(TypedDict):
    """State passed between graph nodes; it has a single ``messages`` key."""

    messages: _MessageLog

# Workflow builder keyed on AgentState. Nodes and edges are registered on it
# further down, and it is finally compiled into the runnable `app_graph`.
graph = StateGraph(AgentState)

# Model Invocation
def invoke_model(state):
    """Send the newest message to the tool-enabled model.

    Only the text content of the last entry in ``state['messages']`` is
    forwarded; earlier messages are not passed to the model.

    Returns:
        A state update whose "messages" list holds the model's single reply.
    """
    latest_message = state['messages'][-1]
    reply = model_with_tools.invoke(latest_message.content)
    return {"messages": [reply]}


graph.add_node("agent", invoke_model)

# Tool Invocation
def invoke_tool(state):
    """Execute the ``multiply`` tool calls requested by the newest message.

    Tool calls are read from the last message's
    ``additional_kwargs["tool_calls"]``. Calls naming any other function are
    skipped. Each handled call yields one AIMessage carrying the result text
    and the originating ``tool_call_id``.

    Returns:
        A state update with the result messages, or a single fallback message
        when no multiply call was found.
    """
    last_message = state['messages'][-1]
    outputs = []

    for call in last_message.additional_kwargs.get("tool_calls", []):
        function_spec = call["function"]
        if function_spec["name"] != "multiply":
            continue  # only the multiply tool is handled here
        # The arguments arrive as a JSON-encoded string.
        product = multiply.invoke(json.loads(function_spec["arguments"]))
        outputs.append(
            AIMessage(
                content=f"Tool Result: {product}",
                additional_kwargs={"tool_call_id": call["id"]},
            )
        )

    if not outputs:
        outputs = [AIMessage(content="No valid tool input provided.")]
    return {"messages": outputs}


graph.add_node("tool", invoke_tool)

# Router Node
def router(state):
    """
    Pass-through node placed between "agent" and the conditional edge.

    The branch ("tool" vs END) is selected by the conditional edge registered
    on the "router" node, not by this function's return value. The previous
    implementation returned ``{"__state__": ...}``, a key that AgentState does
    not declare, so it never influenced routing.
    NOTE(review): depending on the LangGraph version, such an undeclared key
    is either ignored or rejected as an invalid state update — confirm.

    Args:
        state: Current graph state (unused; accepted because nodes receive it).

    Returns:
        An empty "messages" update. Adding an empty list via the operator.add
        annotation on AgentState leaves the message history unchanged.
    """
    return {"messages": []}

graph.add_node("router", router)


def _route_after_router(state):
    """Return "tool" when the newest message carries tool calls, else END."""
    newest_message = state['messages'][-1]
    if newest_message.additional_kwargs.get("tool_calls"):
        return "tool"
    return END


# Add explicit edges: agent -> router -> (tool | END), tool -> END
graph.add_edge("agent", "router")
graph.add_conditional_edges(
    "router",
    _route_after_router,
    {"tool": "tool", END: END},
)
graph.add_edge("tool", END)

# Compile the graph, starting execution at the "agent" node
graph.set_entry_point("agent")
app_graph = graph.compile()

# Render graph visualization with xray for visibility.
# The PNG bytes are passed straight to st.image instead of being written to a
# NamedTemporaryFile(delete=False): Streamlit re-executes this script on every
# interaction, so the old approach left one orphaned temp file per rerun.
try:
    graph_png = app_graph.get_graph(xray=True).draw_mermaid_png()
except Exception as e:  # the diagram is cosmetic; keep the demo usable without it
    graph_png = None
    graph_render_error = e

# Streamlit Interface
st.title("Simple Tool Calling Demo")
if graph_png is not None:
    st.image(graph_png, caption="Workflow Visualization")
else:
    st.warning(f"Workflow visualization unavailable: {graph_render_error}")

def _run_and_show(question):
    """Run one question through the compiled graph and render the outcome.

    The content of the last message in the resulting state is shown with
    st.success; any exception raised while invoking the graph or reading the
    result is shown with st.error instead of crashing the page.
    """
    try:
        result = app_graph.invoke({"messages": [HumanMessage(content=question)]})
        answer = result['messages'][-1].content
    except Exception as e:  # UI boundary: report the failure to the user
        st.error(f"Error: {e}")
    else:
        st.success(answer)


tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

with tab1:
    st.subheader("Try Multiplication")
    first_number = st.number_input("First Number", value=0, step=1)
    second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        _run_and_show(f"What is {first_number} * {second_number}?")

with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        # Do not spend a model call on an empty or whitespace-only question.
        if user_input.strip():
            _run_and_show(user_input)
        else:
            st.warning("Please enter a question before submitting.")

# Sidebar for References
reference_entry = "1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)"
with st.sidebar:
    # Elements created inside this context are placed in the sidebar.
    st.title("Reference:")
    st.markdown(reference_entry)