File size: 3,616 Bytes
54343c0
 
 
e19b96d
012fbad
54343c0
 
 
 
1d4c700
54343c0
e19b96d
 
54343c0
e19b96d
54343c0
 
 
 
e19b96d
54343c0
 
 
 
e19b96d
1d4c700
e19b96d
293631b
1d4c700
54343c0
1d4c700
e19b96d
20d5bf2
 
 
 
 
 
 
54343c0
1d4c700
 
 
1a67ad9
1d4c700
 
e19b96d
 
 
df88f89
1d4c700
 
 
df88f89
1d4c700
e19b96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# ChatOpenAI reads OPENAI_API_KEY from the environment on its own, so the key
# does not need to be copied back into os.environ (doing so with a missing key
# raises TypeError, since os.environ values must be str, not None). Fail early
# with a readable message instead.
if not os.getenv("OPENAI_API_KEY"):
    st.error("OPENAI_API_KEY is not set. Export it before launching the app.")
    st.stop()

# Model Initialization
# temperature=0 keeps the model's tool-calling decisions as repeatable as possible.
model = ChatOpenAI(temperature=0)

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    # @tool exposes the docstring above to the model as the tool description
    # and derives the argument schema from the parameter names/annotations,
    # so both are part of the tool's contract.
    product = first_number * second_number
    return product

# Describe `multiply` in OpenAI's function-calling schema and attach it to
# every request the model makes.
_multiply_tool_spec = convert_to_openai_tool(multiply)
model_with_tools = model.bind(tools=[_multiply_tool_spec])

# State Setup
class AgentState(TypedDict):
    """State passed between the graph's nodes.

    ``messages`` is annotated with ``operator.add`` as its reducer, so the
    list each node returns under ``"messages"`` is concatenated onto the
    existing history instead of replacing it.
    """

    messages: Annotated[Sequence, operator.add]

# Graph builder over AgentState; nodes and edges are registered on it below
# and it is compiled into `app_graph`.
graph = StateGraph(AgentState)

def invoke_model(state):
    """Send the newest message to the tool-enabled model.

    Returns a state update for the ``messages`` channel. When the model
    requested no tool, the update holds only the reply text (a ``str``).
    Otherwise it holds the full response object, so the ``tool_calls`` in its
    ``additional_kwargs`` remain available to the router and the tool node.
    """
    latest_message = state['messages'][-1]
    reply = model_with_tools.invoke(latest_message)
    requested_tools = reply.additional_kwargs.get("tool_calls", [])
    payload = reply if requested_tools else reply.content
    return {"messages": [payload]}


# Node "agent": queries the model via invoke_model; it is also the entry point.
graph.add_node("agent", invoke_model)

def invoke_tool(state):
    """Run the ``multiply`` tool call requested by the latest message.

    Reads OpenAI-style tool calls from ``additional_kwargs["tool_calls"]`` on
    the last message and executes the first one named ``multiply``.

    Returns:
        A state update whose single message is ``"Tool Result: <product>"``
        on success, or ``"Tool Error: <reason>"`` when the model-supplied
        arguments are not valid JSON or do not fit the tool schema. It is
        ``"No tool input provided."`` when no ``multiply`` call is present.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    for tool_call in tool_calls:
        # Tolerate entries without a "function" payload instead of calling
        # .get() on None.
        function = tool_call.get("function") or {}
        if function.get("name") != "multiply":
            continue
        # `arguments` is generated by the model and is not guaranteed to be
        # valid JSON or to match the schema, so validate it before use.
        try:
            arguments = json.loads(function.get("arguments", ""))
            res = multiply.invoke(arguments)
        except (TypeError, ValueError) as err:
            # json.JSONDecodeError and pydantic validation errors are both
            # ValueError subclasses; TypeError covers non-string arguments.
            return {"messages": [f"Tool Error: {err}"]}
        return {"messages": [f"Tool Result: {res}"]}
    return {"messages": ["No tool input provided."]}

# Node "tool": runs the requested multiply call via invoke_tool. Its output
# is the final answer, so the tool node connects directly to END.
graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
# Every run starts at the "agent" node.
graph.set_entry_point("agent")

def router(state):
    """Pick the edge to follow after the "agent" node.

    Returns ``"multiply"`` when the latest message carries OpenAI-style
    ``tool_calls`` in ``additional_kwargs``, and ``"end"`` otherwise.
    """
    last_message = state['messages'][-1]
    # invoke_model stores a plain str (the reply text) when the model asked
    # for no tool. A str has no `additional_kwargs`, and accessing it directly
    # raised AttributeError for every non-tool answer, so anything without
    # that attribute is treated as "no tool calls".
    extra = getattr(last_message, "additional_kwargs", None) or {}
    return "multiply" if extra.get("tool_calls") else "end"

# After "agent", router() returns a label: "multiply" goes to the "tool" node
# and "end" finishes the run. compile() produces the runnable graph used by the UI.
graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
app_graph = graph.compile()

# Save graph visualization as an image
# Streamlit re-executes this whole script on every widget interaction.
# Rendering here unconditionally re-rendered the graph and created a new
# delete=False temp file on each rerun, leaking one PNG per click.
# draw_mermaid_png() renders through a web service by default
# (NOTE(review): confirm for the installed langchain_core version).
# Caching makes the render happen once per process.
@st.cache_resource(show_spinner=False)
def _render_workflow_image() -> str:
    """Render the compiled graph to a PNG temp file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        graph_viz = app_graph.get_graph(xray=True)
        tmpfile.write(graph_viz.draw_mermaid_png())  # Write binary image data to file
        return tmpfile.name


graph_image_path = _render_workflow_image()

# Streamlit Interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph (the PNG written to graph_image_path above)
st.image(graph_image_path, caption="Workflow Visualization")

tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

# Tab 1: builds a multiplication question from two integer inputs so the
# model is expected to route it to the `multiply` tool.
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)

    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        # Same error boundary as the General Query tab: a failed model or
        # tool call shows a message instead of Streamlit's crash page.
        try:
            output = app_graph.invoke({"messages": [question]})
        except Exception as err:  # UI boundary: report the failure to the user
            st.error("Something went wrong. Try again!")
            with st.expander("Error details"):
                st.exception(err)
        else:
            st.success(output['messages'][-1])

# Tab 2: free-form questions; the model decides whether a tool is needed.
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        # Treat whitespace-only input as empty rather than sending it to the model.
        query = user_input.strip()
        if query:
            try:
                result = app_graph.invoke({"messages": [query]})
            except Exception as err:  # UI boundary: surface the cause, don't swallow it
                st.error("Something went wrong. Try again!")
                with st.expander("Error details"):
                    st.exception(err)
            else:
                st.write("Response:")
                st.success(result['messages'][-1])
        else:
            st.warning("Please enter a valid input.")


# Sidebar: link to the notebook this demo is based on.
st.sidebar.title("References")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")