import json
import operator
import os
import tempfile
from typing import Annotated, Sequence, TypedDict

import streamlit as st
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# Environment setup (OPENAI_API_KEY must already be set in the environment)
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

# Model initialization
model = ChatOpenAI(temperature=0)


# Define the tool
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number


# Bind the tool to the model in OpenAI function-calling format
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])


# State setup: messages accumulate across nodes via operator.add
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]


graph = StateGraph(AgentState)


# Model invocation node
def invoke_model(state):
    """Invoke the model on the latest user message."""
    question = state["messages"][-1].content
    response = model_with_tools.invoke(question)
    # Return the model's response, which carries tool_calls if the model chose to use the tool
    return {"messages": [response]}


graph.add_node("agent", invoke_model)


# Tool invocation node
def invoke_tool(state):
    """Invoke the 'multiply' tool if the model requested it."""
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    for tool_call in tool_calls:
        if tool_call["function"]["name"] == "multiply":
            # Extract and parse the arguments
            arguments = json.loads(tool_call["function"]["arguments"])
            result = multiply.invoke(arguments)
            # Return the tool result, preserving the tool_call_id it answers
            return {"messages": [ToolMessage(content=str(result), name="multiply",
                                             tool_call_id=tool_call["id"])]}
    # Fallback when no recognised tool call is present (ToolMessage still requires a tool_call_id)
    return {"messages": [ToolMessage(content="No valid tool input provided.",
                                     name="multiply", tool_call_id="none")]}


graph.add_node("tool", invoke_tool)


# Router node, added explicitly so it appears in the graph visualization.
# It is a pass-through; the actual branching happens in the conditional edge below.
def router(state):
    return {"messages": []}


graph.add_node("router", router)

# Explicit edges: agent -> router, then router -> tool or END
graph.add_edge("agent", "router")
graph.add_conditional_edges(
    "router",
    lambda state: "tool" if state["messages"][-1].additional_kwargs.get("tool_calls") else END,
    {"tool": "tool", END: END},
)
graph.add_edge("tool", END)

# Compile the graph
graph.set_entry_point("agent")
app_graph = graph.compile()

# Save a graph visualization that includes the router node
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)  # xray=True gives a more detailed view
    tmpfile.write(graph_viz.draw_mermaid_png())
    graph_image_path = tmpfile.name

# Streamlit interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph
st.image(graph_image_path, caption="Workflow Visualization")

# Tabbed interface
tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

# Multiplication tool tab
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)
    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        try:
            output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
            st.success(output["messages"][-1].content)
        except Exception as e:
            st.error(f"Error: {e}")

# General query tab
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")
    if st.button("Submit"):
        if user_input:
            try:
                result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
                st.write("Response:")
                st.success(result["messages"][-1].content)
            except Exception as e:
                st.error(f"Error: {e}")
        else:
            st.warning("Please enter a valid input.")

# Sidebar for references
st.sidebar.title("References")
st.sidebar.markdown(
    "1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)"
)
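# Usage note (assumption: the script above is saved as app.py; any filename works):
#   streamlit run app.py
# Streamlit serves the app locally and opens it in the browser, where both tabs call app_graph.invoke().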