"""Streamlit demo of LangGraph tool calling with a single ``multiply`` tool.

The graph has two nodes:

* ``agent`` - calls the OpenAI chat model with the ``multiply`` tool bound.
* ``tool``  - runs ``multiply`` when the model asked for it.

``router`` sends the flow from ``agent`` to ``tool`` when the model response
carries tool calls, and to ``END`` otherwise.
"""
import json
import operator
import os
import tempfile  # noqa: F401  (no longer used below; import retained, not removed)
from typing import Annotated, Sequence, TypedDict

import streamlit as st
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# Environment Setup
# The previous `os.environ[...] = os.getenv(...)` was a no-op when the key was
# set and raised TypeError when it was not. Fail with a readable message instead.
if not os.getenv("OPENAI_API_KEY"):
    st.error("OPENAI_API_KEY is not set. Export it before starting the app.")
    st.stop()

# Model Initialization
model = ChatOpenAI(temperature=0)


# Define the tool
# NOTE: the docstring below is used as the tool description; keep its wording.
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number


# Bind tool to model
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])


# State Setup
class AgentState(TypedDict):
    """Graph state: the running message list, merged across nodes by concatenation."""

    # `operator.add` is the reducer: each node's returned list is appended.
    messages: Annotated[Sequence[BaseMessage], operator.add]


graph = StateGraph(AgentState)


# Model Invocation
def invoke_model(state: AgentState) -> dict[str, list[BaseMessage]]:
    """Call the tool-enabled model on the most recent message.

    Args:
        state: Current graph state; only ``state["messages"][-1]`` is read.

    Returns:
        A state update whose ``messages`` list holds the model response. The
        response is returned unchanged so that any tool calls it carries stay
        visible to ``router`` and ``invoke_tool``.
    """
    latest = state["messages"][-1]
    # A HumanMessage is reduced to its text; anything else is passed through.
    prompt = latest.content if isinstance(latest, HumanMessage) else latest
    response = model_with_tools.invoke(prompt)

    # Defensive: normalise a plain-text reply into a message object.
    if isinstance(response, str):
        response = AIMessage(content=response)

    return {"messages": [response]}


graph.add_node("agent", invoke_model)


# Tool Invocation
def invoke_tool(state: AgentState) -> dict[str, list[BaseMessage]]:
    """Run the ``multiply`` tool for the first matching tool call.

    Args:
        state: Current graph state; tool calls are read from the
            ``additional_kwargs`` of the last message.

    Returns:
        A state update with one ``AIMessage``: the tool result for the first
        ``multiply`` call found, or a fallback message when there is none.

    Raises:
        json.JSONDecodeError: If the tool call arguments are not valid JSON.
    """
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])

    for tool_call in tool_calls:
        call_spec = tool_call.get("function", {})
        if call_spec.get("name") == "multiply":
            arguments = json.loads(call_spec.get("arguments"))
            result = multiply.invoke(arguments)
            return {"messages": [AIMessage(content=f"Tool Result: {result}")]}

    return {"messages": [AIMessage(content="No valid tool input provided.")]}


graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
graph.set_entry_point("agent")


# Router Logic
def router(state: AgentState) -> str:
    """Return ``"tool"`` when the last message carries tool calls, else ``END``."""
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    return "tool" if tool_calls else END


graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
app_graph = graph.compile()


@st.cache_data(show_spinner=False)
def _render_graph_png() -> bytes:
    """Render the compiled graph as PNG bytes.

    Cached because Streamlit re-executes this script on every interaction;
    without the cache each rerun would render the image again. Returning bytes
    (rather than writing a ``delete=False`` temp file) also avoids leaving a
    new file on disk per rerun.
    """
    return app_graph.get_graph(xray=True).draw_mermaid_png()


def _answer(question: str) -> str:
    """Run the graph on one user question and return the final message text."""
    result = app_graph.invoke({"messages": [HumanMessage(content=question)]})
    return result["messages"][-1].content


# Streamlit Interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph. Rendering is best-effort: a failure here should
# not take the rest of the app down.
try:
    graph_png = _render_graph_png()
except Exception as e:
    st.warning(f"Workflow visualization unavailable: {e}")
else:
    st.image(graph_png, caption="Workflow Visualization")

tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

# Multiplication Tool Tab
with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)

    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        try:
            st.success(_answer(question))
        except Exception as e:
            st.error(f"Error: {e}")

# General Query Tab
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        if user_input:
            try:
                answer = _answer(user_input)
                st.write("Response:")
                st.success(answer)
            except Exception as e:
                st.error(f"Error: {e}")
        else:
            st.warning("Please enter a valid input.")

# Sidebar for References
st.sidebar.title("References")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")