import json
import operator
import os
import tempfile
from typing import Annotated, Sequence, TypedDict

import streamlit as st
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# Environment Setup: fail fast with a clear message if the key is missing.
# (Re-assigning os.environ["OPENAI_API_KEY"] from os.getenv is a no-op when
# the key is set and raises a TypeError when it is not.)
if not os.getenv("OPENAI_API_KEY"):
    raise EnvironmentError("OPENAI_API_KEY is not set.")

# Model Initialization
model = ChatOpenAI(temperature=0)

# Define the tool
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number

# Bind the tool's OpenAI function schema to the model
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

# State Setup: operator.add makes each node's returned messages append to,
# rather than overwrite, the running message history.
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]

graph = StateGraph(AgentState)

# Model Invocation: pass the full message history, not just the last
# message's text, so the model keeps context across turns.
def invoke_model(state):
    response = model_with_tools.invoke(state["messages"])
    return {"messages": [response]}

graph.add_node("agent", invoke_model)

# Tool Invocation: execute every multiply call the model requested. The
# result is wrapped in an AIMessage for display; if it were fed back to the
# model, a ToolMessage with the matching tool_call_id would be required.
def invoke_tool(state):
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    tool_results = []
    for tool_call in tool_calls:
        if tool_call["function"]["name"] == "multiply":
            arguments = json.loads(tool_call["function"]["arguments"])
            result = multiply.invoke(arguments)
            tool_results.append(
                AIMessage(
                    content=f"Tool Result: {result}",
                    additional_kwargs={"tool_call_id": tool_call["id"]},
                )
            )
    return {"messages": tool_results or [AIMessage(content="No valid tool input provided.")]}

graph.add_node("tool", invoke_tool)

# Router: a plain function passed to add_conditional_edges, not a graph node.
# (A node cannot steer the graph by returning {"__state__": ...}; routing
# decisions belong in the conditional-edge callable, which returns the name
# of the next node or END.)
def router(state):
    """Route to the tool node if tool calls exist; otherwise end the workflow."""
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    return "tool" if tool_calls else END

# Add explicit edges: agent decides via router; tool (if taken) finishes the run
graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
graph.add_edge("tool", END)

# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()

# Save graph visualization with xray for visibility
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)
    tmpfile.write(graph_viz.draw_mermaid_png())
    graph_image_path = tmpfile.name

# Streamlit Interface
st.title("Simple Tool Calling Demo")
st.image(graph_image_path, caption="Workflow Visualization")

tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

with tab1:
    st.subheader("Try Multiplication")
    first_number = st.number_input("First Number", value=0, step=1)
    second_number = st.number_input("Second Number", value=0, step=1)
    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        try:
            output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
            st.success(output["messages"][-1].content)
        except Exception as e:
            st.error(f"Error: {e}")

with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")
    if st.button("Submit"):
        try:
            result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
            st.success(result["messages"][-1].content)
        except Exception as e:
            st.error(f"Error: {e}")

# Sidebar for References
st.sidebar.title("Reference:")
st.sidebar.markdown(
    "1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)"
)
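
# How to run (a minimal sketch; the filename app.py is an assumption, not
# given in the source):
#   export OPENAI_API_KEY=sk-...
#   streamlit run app.py
#
# A multiplication question such as "What is 6 * 7?" triggers a tool call and
# is routed agent -> tool -> END, displaying "Tool Result: 42"; queries with
# no tool calls are routed agent -> END and answered by the model directly.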