import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END
# Environment Setup
# os.environ values must be str; the original `os.environ[...] = os.getenv(...)`
# raised TypeError when the variable was unset. Fail with a readable message
# in the UI instead, and halt the script run.
_api_key = os.getenv("OPENAI_API_KEY")
if not _api_key:
    st.error("OPENAI_API_KEY environment variable is not set.")
    st.stop()
os.environ["OPENAI_API_KEY"] = _api_key
# Model Initialization
# temperature=0 for deterministic tool-calling behavior.
model = ChatOpenAI(temperature=0)
# Define the tool
# NOTE: the docstring below doubles as the tool description sent to the
# model, and the parameter names become the function-call argument keys
# that invoke_tool later json-decodes — do not rename them casually.
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number
# Bind tool to model
# convert_to_openai_tool turns the decorated function into the OpenAI
# function-calling JSON schema the chat model expects.
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
# State Setup
class AgentState(TypedDict):
    # Conversation history. The operator.add annotation is LangGraph's
    # reducer: each node's returned messages are appended to this
    # sequence rather than replacing it.
    messages: Annotated[Sequence, operator.add]
# Build the workflow graph over the state schema above.
graph = StateGraph(AgentState)
# Model Invocation
def invoke_model(state):
    """Call the tool-bound model and append its reply to the state.

    Passes the full message sequence (not just the last message's text)
    so the model keeps prior turns and message roles; the original
    forwarded only ``state['messages'][-1].content``, which discarded
    conversation history.

    Returns a partial state update; the ``operator.add`` reducer on
    AgentState.messages appends the response to the history.
    """
    response = model_with_tools.invoke(state["messages"])
    return {"messages": [response]}
graph.add_node("agent", invoke_model)
# Tool Invocation
def invoke_tool(state):
    """Run every `multiply` tool call requested by the last model message.

    Reads OpenAI-style tool calls from the last message's
    additional_kwargs and returns one AIMessage per executed call,
    tagged with the originating tool_call_id. Falls back to a single
    notice message when no valid call was present.
    """
    last_message = state["messages"][-1]
    requested_calls = last_message.additional_kwargs.get("tool_calls", [])
    executed = [
        AIMessage(
            content=f"Tool Result: {multiply.invoke(json.loads(call['function']['arguments']))}",
            additional_kwargs={"tool_call_id": call["id"]},
        )
        for call in requested_calls
        if call["function"]["name"] == "multiply"
    ]
    if not executed:
        executed = [AIMessage(content="No valid tool input provided.")]
    return {"messages": executed}
graph.add_node("tool", invoke_tool)
# Fixed Router Node
def router(state):
"""
Route to tool if tool calls exist; otherwise END the workflow.
"""
tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
if tool_calls:
return {"__state__": "tool"} # Direct routing to tool
else:
return {"__state__": END} # End workflow if no tool calls
graph.add_node("router", router)
# Add explicit edges
# agent always hands off to the router node first.
graph.add_edge("agent", "router")
# Branch after "router": if the agent's last message carries OpenAI-style
# tool_calls, execute the "tool" node; otherwise terminate the workflow.
graph.add_conditional_edges(
    "router",
    lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END,
    {"tool": "tool", END: END}
)
# Tool output is terminal — the result is not fed back to the model.
graph.add_edge("tool", END)
# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()
# Save graph visualization with xray for visibility
# NOTE(review): draw_mermaid_png() typically renders via a remote mermaid
# service — confirm network access is acceptable at app startup.
# delete=False keeps the PNG on disk so st.image can read it after the
# `with` block closes the handle; the file is never cleaned up (a small
# leak, acceptable for a demo).
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)
    tmpfile.write(graph_viz.draw_mermaid_png())
    graph_image_path = tmpfile.name
# Streamlit Interface
st.title("Simple Tool Calling Demo")
# Show the compiled workflow diagram rendered above.
st.image(graph_image_path, caption="Workflow Visualization")
# Two demo modes: a structured multiplication prompt and free-form queries.
tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
with tab1:
    # Structured demo: build a natural-language multiplication question
    # from two numeric inputs and run it through the graph.
    st.subheader("Try Multiplication")
    num_a = st.number_input("First Number", value=0, step=1)
    num_b = st.number_input("Second Number", value=0, step=1)
    if st.button("Multiply"):
        question = f"What is {num_a} * {num_b}?"
        try:
            result_state = app_graph.invoke({"messages": [HumanMessage(content=question)]})
            st.success(result_state['messages'][-1].content)
        except Exception as exc:
            st.error(f"Error: {exc}")
with tab2:
    # Free-form demo: forward an arbitrary user question to the graph.
    st.subheader("General Query")
    query_text = st.text_input("Enter your question here")
    if st.button("Submit"):
        try:
            response_state = app_graph.invoke({"messages": [HumanMessage(content=query_text)]})
            st.success(response_state['messages'][-1].content)
        except Exception as err:
            st.error(f"Error: {err}")
# Sidebar for References
# (Removed a stray trailing "|" after the markdown call — extraction
# residue that made the statement a syntax error.)
st.sidebar.title("Reference:")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")