import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END
# Environment Setup
# os.environ rejects None values, so only propagate the key when it is
# actually set; the original unconditional assignment raised TypeError
# whenever OPENAI_API_KEY was missing from the environment.
_api_key = os.getenv("OPENAI_API_KEY")
if _api_key:
    os.environ["OPENAI_API_KEY"] = _api_key
# Model Initialization
# temperature=0 gives deterministic responses, which suits a tool-calling demo.
# Model name is the library default — TODO confirm this is the intended model.
model = ChatOpenAI(temperature=0)
# Define the tool
# @tool wraps the function as a LangChain tool; it is later invoked via
# multiply.invoke({...}) with a dict of the parameters below.
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number
# Bind tool to model
# convert_to_openai_tool renders the tool as an OpenAI function-calling schema;
# bind() attaches it so the model may emit tool_calls for "multiply".
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
# State Setup
class AgentState(TypedDict):
    # Conversation history. Annotated with operator.add so each node's
    # returned "messages" list is appended to — not substituted for — the
    # existing state when LangGraph merges node outputs.
    messages: Annotated[Sequence, operator.add]
graph = StateGraph(AgentState)
# Model Invocation
def invoke_model(state):
"""
Invoke the model and handle tool invocation logic.
"""
question = state['messages'][-1].content
response = model_with_tools.invoke(question)
# Return the model's response with tool_calls, if any
return {"messages": [response]}
graph.add_node("agent", invoke_model)
# Tool Invocation
def invoke_tool(state):
"""
Invoke the 'multiply' tool if it's called by the model.
"""
tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
for tool_call in tool_calls:
if tool_call["function"]["name"] == "multiply":
# Extract and parse the arguments
arguments = json.loads(tool_call["function"]["arguments"])
result = multiply.invoke(arguments)
# Return the tool result while maintaining tool_call_id for future use
return {"messages": [ToolMessage(content=str(result), name="multiply", tool_call_id=tool_call["id"])]}
return {"messages": [ToolMessage(content="No valid tool input provided.")]}
graph.add_node("tool", invoke_tool)
# Router Node: Manual Addition
def router(state):
    """Pass-through routing node: contributes nothing to state.

    Returning an empty messages list leaves the model's response as the last
    message, so the conditional edge leaving this node can inspect its
    tool_calls. The original appended a dummy ToolMessage, which both fails
    validation (ToolMessage requires tool_call_id) and would shadow the model
    response, making the tool branch unreachable.
    """
    return {"messages": []}
graph.add_node("router", router)
# Edges: the agent always hands off to the router; the router then branches
# on whether the last message in state carries tool_calls — to the "tool"
# node when it does, straight to END when it does not.
graph.add_edge("agent", "router")
graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
graph.add_edge("tool", END)
# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()
# Save graph visualization with router explicitly included.
# delete=False keeps the PNG on disk after the with-block so st.image can
# read it later; the temp file is never removed afterwards (demo trade-off).
# NOTE(review): draw_mermaid_png presumably renders via an external service —
# confirm network availability is acceptable at startup.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)  # xray=True for a detailed view
    tmpfile.write(graph_viz.draw_mermaid_png())
    graph_image_path = tmpfile.name
# Streamlit Interface
st.title("Simple Tool Calling Demo")
# Display the workflow graph rendered above.
st.image(graph_image_path, caption="Workflow Visualization")
# Tabbed Interface: one tab drives the multiply tool, one takes free-form text.
tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
# Multiplication Tool Tab
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)
    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)
    if st.button("Multiply"):
        # Phrase the inputs as a natural-language question so the model
        # decides to emit a multiply tool_call.
        question = f"What is {first_number} * {second_number}?"
        try:
            output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
            # The last message in the final state holds the answer.
            st.success(output['messages'][-1].content)
        except Exception as e:
            st.error(f"Error: {e}")
# General Query Tab
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")
    if st.button("Submit"):
        if user_input:
            try:
                # Same graph as the multiplication tab; non-tool questions
                # route to END after the model answers directly.
                result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
                st.write("Response:")
                st.success(result['messages'][-1].content)
            except Exception as e:
                st.error(f"Error: {e}")
        else:
            st.warning("Please enter a valid input.")
# Sidebar for References
st.sidebar.title("References")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")