File size: 4,533 Bytes
54343c0
 
 
e19b96d
012fbad
54343c0
 
 
0a642a3
54343c0
1d4c700
54343c0
e19b96d
 
54343c0
e19b96d
54343c0
 
beab746
54343c0
 
e19b96d
54343c0
 
beab746
54343c0
 
e19b96d
1d4c700
e19b96d
293631b
1d4c700
54343c0
beab746
1d4c700
beab746
 
 
0a642a3
20d5bf2
beab746
0a642a3
beab746
54343c0
1d4c700
 
beab746
1d4c700
beab746
 
 
1a67ad9
1d4c700
0a642a3
141b6b8
0a642a3
beab746
141b6b8
 
0a642a3
df88f89
141b6b8
1d4c700
df88f89
0a642a3
1d4c700
beab746
0a642a3
e19b96d
0a642a3
 
 
 
 
 
 
 
 
e19b96d
 
0a642a3
e19b96d
0a642a3
beab746
e19b96d
 
 
 
 
 
 
 
0a642a3
e19b96d
 
beab746
e19b96d
 
 
 
 
 
 
 
 
 
 
beab746
 
 
 
 
e19b96d
beab746
e19b96d
 
 
 
 
 
 
beab746
e19b96d
beab746
e19b96d
beab746
e19b96d
 
 
beab746
e19b96d
beab746
0a642a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# Fail fast with a readable message when the API key is missing. Writing
# os.getenv(...) back into os.environ raises TypeError for an unset variable
# (os.environ only accepts str values) and changes nothing when it is set,
# so an explicit presence check is used instead.
if not os.getenv("OPENAI_API_KEY"):
    st.error("OPENAI_API_KEY environment variable is not set.")
    st.stop()

# Model Initialization
# Chat model created with temperature=0.
# NOTE(review): ChatOpenAI is expected to read OPENAI_API_KEY from the
# environment on its own - confirm against the langchain-openai docs.
model = ChatOpenAI(temperature=0)

# Define the tool
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it as the tool description - confirm
    # before rewording it.
    product = first_number * second_number
    return product

# Bind tool to model
# Build the OpenAI-format tool spec first, then attach it to the model so
# that model_with_tools can emit tool calls for `multiply`.
_multiply_tool_spec = convert_to_openai_tool(multiply)
model_with_tools = model.bind(tools=[_multiply_tool_spec])

# State Setup
# Schema of the state dictionary that is handed to every graph node.
class AgentState(TypedDict):
    # Messages accumulated while the graph runs. The Annotated metadata is
    # operator.add; presumably LangGraph uses it as the reducer, so the
    # "messages" list a node returns is concatenated onto the existing list
    # instead of replacing it - confirm against the LangGraph docs.
    messages: Annotated[Sequence, operator.add]

# Graph builder using AgentState as its state schema; nodes and edges are
# registered on it further down before it is compiled.
graph = StateGraph(AgentState)

# Model Invocation
def invoke_model(state):
    """
    Agent node: forward the newest message's text to the tool-enabled model.

    Only the content of the last entry in state['messages'] is sent. The
    model's reply (which may carry tool_calls) is returned as a one-item
    list under the "messages" key.
    """
    latest_message = state['messages'][-1]
    reply = model_with_tools.invoke(latest_message.content)
    return {"messages": [reply]}

graph.add_node("agent", invoke_model)

# Tool Invocation
def invoke_tool(state):
    """
    Run the 'multiply' tool for every matching tool call on the last message.

    Reads the OpenAI-style tool calls stored under
    ``additional_kwargs["tool_calls"]`` of ``state['messages'][-1]``. For each
    call whose function name is "multiply", the JSON-encoded arguments are
    decoded and passed to the tool.

    Returns:
        A state update ``{"messages": [...]}`` with one ToolMessage per
        executed call, each carrying that call's id. When no call matched,
        a single explanatory ToolMessage is returned instead.

    Raises:
        json.JSONDecodeError: if a matching call's arguments are not valid
            JSON.
    """
    # `or []` also covers the case where the key is present but set to None.
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls") or []

    results = [
        ToolMessage(
            content=str(multiply.invoke(json.loads(tool_call["function"]["arguments"]))),
            name="multiply",
            tool_call_id=tool_call["id"],
        )
        for tool_call in tool_calls
        if tool_call["function"]["name"] == "multiply"
    ]
    if results:
        return {"messages": results}

    # ToolMessage needs a tool_call_id to be constructed; reuse the first
    # call's id when there is one, otherwise fall back to an empty string.
    fallback_id = tool_calls[0]["id"] if tool_calls else ""
    return {"messages": [ToolMessage(content="No valid tool input provided.", tool_call_id=fallback_id)]}


graph.add_node("tool", invoke_tool)

# Router Node: Manual Addition
def router(state):
    """
    Pass-through node that makes routing an explicit step in the graph.

    It intentionally contributes no messages. The conditional edge leaving
    this node and the tool node both look at state['messages'][-1] for
    tool_calls, so the model's reply has to stay the last message; appending
    a placeholder message here would hide it. (A ToolMessage built without a
    tool_call_id is also expected to fail validation.)

    Returns:
        ``{"messages": []}`` - an empty update that leaves the state's
        message list unchanged.
    """
    return {"messages": []}

graph.add_node("router", router)


def _route_after_router(state):
    """Return "tool" when the last message carries tool_calls, otherwise END."""
    pending_calls = state['messages'][-1].additional_kwargs.get("tool_calls")
    if pending_calls:
        return "tool"
    return END


# Add explicit edges from agent to router, and router to tool/END
graph.add_edge("agent", "router")
graph.add_conditional_edges("router", _route_after_router, {"tool": "tool", END: END})
graph.add_edge("tool", END)

# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()

# Save graph visualization with router explicitly included
# Streamlit re-executes this script on every interaction, so creating a fresh
# delete=False temp file per run would leave an ever-growing pile of PNGs.
# The image is rendered first (so a rendering failure leaves no stray file)
# and then published to a single per-process path.
graph_viz = app_graph.get_graph(xray=True)  # Ensures detailed visualization
graph_png = graph_viz.draw_mermaid_png()
graph_image_path = os.path.join(
    tempfile.gettempdir(), f"workflow_graph_{os.getpid()}.png"
)
with tempfile.NamedTemporaryFile(
    suffix=".png", dir=tempfile.gettempdir(), delete=False
) as tmpfile:
    tmpfile.write(graph_png)
# os.replace swaps the fully written file into place in one step, so a
# concurrent reader never sees a half-written image.
os.replace(tmpfile.name, graph_image_path)

# Streamlit Interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph
st.image(graph_image_path, caption="Workflow Visualization")

# Tabbed Interface
tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

# Multiplication Tool Tab
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)

    # Calling the widget on a column object places it inside that column.
    first_number = col1.number_input("First Number", value=0, step=1)
    second_number = col2.number_input("Second Number", value=0, step=1)

    multiply_clicked = st.button("Multiply")
    if multiply_clicked:
        # Phrase the two inputs as a question so the model can decide to
        # call the multiply tool.
        question = f"What is {first_number} * {second_number}?"
        try:
            initial_state = {"messages": [HumanMessage(content=question)]}
            output = app_graph.invoke(initial_state)
            # The last message in the final state is what gets displayed.
            st.success(output['messages'][-1].content)
        except Exception as e:
            st.error(f"Error: {e}")

# General Query Tab
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        # Treat whitespace-only input as empty so a blank question is never
        # sent to the model.
        query = user_input.strip()
        if query:
            try:
                result = app_graph.invoke({"messages": [HumanMessage(content=query)]})
                st.write("Response:")
                # The last message in the final state is what gets displayed.
                st.success(result['messages'][-1].content)
            except Exception as e:
                # Top-level UI boundary: surface any failure to the user.
                st.error(f"Error: {e}")
        else:
            st.warning("Please enter a valid input.")

# Sidebar for References
# Elements created inside the `with st.sidebar:` block are placed in the sidebar.
with st.sidebar:
    st.title("References")
    st.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")