File size: 4,554 Bytes
7054888
 
 
 
737e35e
7054888
 
 
a5c39c6
7054888
 
 
 
 
 
 
 
 
a5c39c6
7054888
 
 
 
 
a5c39c6
7054888
 
 
 
 
 
 
 
a5c39c6
7054888
a5c39c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7054888
 
 
a5c39c6
7054888
a5c39c6
 
 
7054888
 
a5c39c6
 
 
 
 
7054888
 
 
 
 
a5c39c6
7054888
a5c39c6
 
 
 
 
7054888
a5c39c6
7054888
 
737e35e
 
 
a5c39c6
737e35e
 
7054888
737e35e
 
 
 
7054888
737e35e
7054888
a5c39c6
7054888
 
 
 
 
 
 
 
 
 
 
a5c39c6
 
 
 
 
7054888
a5c39c6
7054888
 
 
 
 
 
 
a5c39c6
 
7054888
a5c39c6
7054888
a5c39c6
7054888
737e35e
 
a5c39c6
737e35e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# ChatOpenAI reads OPENAI_API_KEY from the process environment by itself.
# Do not write os.getenv(...) back into os.environ: that is a no-op when the
# key is set, and raises TypeError (os.environ values must be str) when it is
# unset. Stop early with a readable message instead.
if not os.getenv("OPENAI_API_KEY"):
    st.error("OPENAI_API_KEY is not set. Export it before launching the app.")
    st.stop()

# Model Initialization
# temperature=0 minimizes sampling randomness in the model's replies.
model = ChatOpenAI(temperature=0)

# Define the tool.
# LangChain's @tool decorator derives the tool's argument schema from this
# signature and its description from the docstring. Both are sent to the
# model, so they are intentionally left unchanged.
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    product = operator.mul(first_number, second_number)
    return product

# Bind tool to model: convert the @tool into its OpenAI function-calling
# schema once, then attach it as the `tools` kwarg for calls made through
# model_with_tools.
_multiply_spec = convert_to_openai_tool(multiply)
model_with_tools = model.bind(tools=[_multiply_spec])

# State Setup
class AgentState(TypedDict):
    """State passed between the graph's nodes.

    ``messages`` carries ``operator.add`` as its LangGraph reducer, so the
    ``{"messages": [...]}`` update returned by each node is concatenated onto
    the existing messages instead of replacing them.
    """

    # Annotated[..., operator.add]: LangGraph merges old and new values with
    # this function (list concatenation for the lists used in this app).
    messages: Annotated[Sequence, operator.add]

# Single graph over AgentState; nodes and edges are registered further down.
graph = StateGraph(AgentState)

# Model Invocation
def invoke_model(state):
    """
    Call the tool-bound model on the latest message in the graph state.

    Args:
        state: Graph state whose ``messages`` list ends with the message to
            answer. When entered from the start node this is the user's
            HumanMessage.

    Returns:
        A partial state update ``{"messages": [msg]}``. ``msg`` is the raw
        model response when it requests tool calls (so ``router`` and
        ``invoke_tool`` can read them). Otherwise it is a plain AIMessage
        holding only the reply text.
    """
    latest = state["messages"][-1]
    # A one-element message list is always a valid chat-model input. Passing a
    # bare message object (the old non-HumanMessage fallback) is rejected by
    # chat models with ValueError.
    response = model_with_tools.invoke([latest])

    if response.additional_kwargs.get("tool_calls"):
        # Keep the full response: the tool-call payload lives in additional_kwargs.
        return {"messages": [response]}

    return {"messages": [AIMessage(content=response.content)]}


graph.add_node("agent", invoke_model)

# Tool Invocation
def invoke_tool(state):
    """
    Execute every ``multiply`` call requested by the latest model message.

    Args:
        state: Graph state whose last message is the model response produced
            by ``invoke_model``, with OpenAI-style ``tool_calls`` in its
            ``additional_kwargs``.

    Returns:
        ``{"messages": [AIMessage]}``. The message content holds one
        ``"Tool Result: <n>"`` entry per executed call, or
        ``"No valid tool input provided."`` when no call targets ``multiply``.

    Exceptions from ``json.loads`` (malformed arguments) or
    ``multiply.invoke`` (invalid arguments) propagate to the caller.
    """
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    results = []
    # The model may issue several parallel calls (e.g. "2*3 and 4*5").
    # Answer all of them instead of returning after the first.
    for tool_call in tool_calls:
        function = tool_call.get("function", {})
        if function.get("name") != "multiply":
            continue
        arguments = json.loads(function.get("arguments"))
        results.append(f"Tool Result: {multiply.invoke(arguments)}")

    if not results:
        return {"messages": [AIMessage(content="No valid tool input provided.")]}
    # A blank line between entries keeps one result per line when the
    # content is rendered as markdown (e.g. by st.success).
    return {"messages": [AIMessage(content="\n\n".join(results))]}


graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
graph.set_entry_point("agent")

# Router Logic
def router(state):
    """
    Choose the node that follows ``agent``.

    Returns ``"tool"`` when the latest message carries a non-empty
    ``tool_calls`` entry in ``additional_kwargs``, otherwise ``END``.
    """
    latest_message = state['messages'][-1]
    if latest_message.additional_kwargs.get("tool_calls", []):
        return "tool"
    return END

# After "agent", router() names the next step; this map resolves each of its
# possible return values to a node or the END sentinel.
_agent_routes = {"tool": "tool", END: END}
graph.add_conditional_edges("agent", router, _agent_routes)
app_graph = graph.compile()

# Graph visualization
@st.cache_data(show_spinner=False)
def _render_workflow_png():
    """
    Render the compiled graph (with xray=True) as PNG bytes.

    Streamlit re-executes this script on every widget interaction, so the
    result is cached to avoid re-rendering on each rerun. ``draw_mermaid_png``
    presumably calls a remote Mermaid rendering service by default (confirm
    for the installed version), so the caller handles failures.
    """
    return app_graph.get_graph(xray=True).draw_mermaid_png()


# Streamlit Interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph. st.image accepts raw bytes, so no temporary
# file is needed; a NamedTemporaryFile(delete=False) per rerun would leak one
# PNG on disk for every interaction.
try:
    st.image(_render_workflow_png(), caption="Workflow Visualization")
except Exception as exc:  # the diagram is optional; keep the app usable
    st.warning(f"Could not render the workflow graph: {exc}")

def _ask_graph(question):
    """
    Run one question through the compiled graph and return the final reply.

    Args:
        question: Text sent to the graph as a single HumanMessage.

    Returns:
        The ``content`` of the last message in the graph's output state.

    Exceptions raised by the graph propagate so each tab can report them
    with ``st.error``.
    """
    output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
    return output['messages'][-1].content


tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

# Multiplication Tool Tab
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)

    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        try:
            answer = _ask_graph(question)
        except Exception as e:
            st.error(f"Error: {e}")
        else:
            st.success(answer)

# General Query Tab
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        query = user_input.strip()
        # Whitespace-only input counts as empty instead of being sent to the model.
        if not query:
            st.warning("Please enter a valid input.")
        else:
            try:
                answer = _ask_graph(query)
            except Exception as e:
                st.error(f"Error: {e}")
            else:
                st.write("Response:")
                st.success(answer)

# Sidebar: links to the material this demo is based on.
with st.sidebar:
    st.title("References")
    st.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")