|
import os |
|
import json |
|
import operator |
|
import streamlit as st |
|
import tempfile |
|
from typing import TypedDict, Annotated, Sequence |
|
from langchain_openai import ChatOpenAI |
|
from langchain_core.tools import tool |
|
from langchain_core.utils.function_calling import convert_to_openai_tool |
|
from langgraph.graph import StateGraph, END |
|
|
|
|
|
# ChatOpenAI reads OPENAI_API_KEY from the environment on its own; re-exporting
# it here is defensive.  Guard the unset case: assigning None into os.environ
# raises an opaque "TypeError: str expected, not NoneType" at import time.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key

# Deterministic chat model (temperature=0) used by the agent node below.
model = ChatOpenAI(temperature=0)
|
|
|
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    # NOTE: the docstring above doubles as the tool description sent to the
    # model, so it is kept verbatim.  operator.mul(a, b) == a * b.
    product = operator.mul(first_number, second_number)
    return product
|
|
|
# Attach the multiply tool's OpenAI function schema so the model can emit
# tool_calls for it.  (model.bind_tools([multiply]) is the newer shorthand —
# left as-is here to avoid a behavior change; confirm before switching.)
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state passed between graph nodes.

    `messages` is annotated with operator.add, so each node's returned
    message list is appended to the existing state rather than replacing it.
    """

    messages: Annotated[Sequence, operator.add]
|
|
|
# Build the workflow graph over the shared AgentState schema.
graph = StateGraph(AgentState)
|
|
|
def invoke_model(state):
    """Agent node: send the latest message to the tool-bound model.

    Returns a single-element message list; the state's operator.add reducer
    appends it to the running conversation.
    """
    latest_message = state["messages"][-1]
    response = model_with_tools.invoke(latest_message)
    return {"messages": [response]}
|
|
|
# "agent" node: lets the model either answer directly or request a tool call.
graph.add_node("agent", invoke_model)
|
|
|
def invoke_tool(state):
    """Tool node: execute the first requested `multiply` call, if any.

    Reads OpenAI-style tool calls from the last message's additional_kwargs;
    returns the tool output (or a fallback note) as a new message.
    """
    last_message = state["messages"][-1]
    requested_calls = last_message.additional_kwargs.get("tool_calls", [])
    for call in requested_calls:
        function_payload = call.get("function")
        if function_payload.get("name") == "multiply":
            # Arguments arrive as a JSON string from the model.
            arguments = json.loads(function_payload.get("arguments"))
            result = multiply.invoke(arguments)
            return {"messages": [f"Tool Result: {result}"]}
    return {"messages": ["No tool input provided."]}
|
|
|
# "tool" node executes the requested tool call; the run ends right after it.
graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
# Every invocation starts at the agent node.
graph.set_entry_point("agent")
|
|
|
def router(state):
    """Decide the next hop after the agent node.

    Returns "multiply" when the last model message requested any tool call,
    otherwise "end" to finish the run.
    """
    last_message = state["messages"][-1]
    requested_calls = last_message.additional_kwargs.get("tool_calls", [])
    if requested_calls:
        return "multiply"
    return "end"
|
|
|
# Branch after the agent: tool call requested -> "tool" node, else terminate.
graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
app_graph = graph.compile()
|
|
|
|
|
# Render the compiled graph to a PNG for display.  draw_mermaid_png() calls
# the remote mermaid.ink service by default, so rendering can fail with no
# network access — degrade gracefully instead of crashing the app at startup.
graph_image_path = None
try:
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        graph_viz = app_graph.get_graph(xray=True)
        tmpfile.write(graph_viz.draw_mermaid_png())
        graph_image_path = tmpfile.name
except Exception:
    # Best-effort: the visualization is cosmetic, so any failure is non-fatal.
    pass

st.title("Simple Tool Calling Demo")

if graph_image_path:
    st.image(graph_image_path, caption="Workflow Visualization")
else:
    st.info("Workflow visualization unavailable.")
|
|
|
# Two demo surfaces: a structured multiplication form and a free-form prompt.
tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

with tab1:
    st.subheader("Try Multiplication")
    left_col, right_col = st.columns(2)

    # Integer inputs (step=1) for the two operands.
    with left_col:
        first_number = st.number_input("First Number", value=0, step=1)
    with right_col:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        # Phrase the request in natural language so the model decides to
        # call the multiply tool via the graph.
        question = f"What is {first_number} * {second_number}?"
        output = app_graph.invoke({"messages": [question]})
        st.success(output["messages"][-1])
|
|
|
with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        if user_input:
            # Keep the try body minimal: only the graph invocation can raise.
            try:
                result = app_graph.invoke({"messages": [user_input]})
            except Exception as exc:
                # Surface the failure reason instead of silently discarding
                # it (the original bound `e` and never used it).
                st.error(f"Something went wrong. Try again! ({exc})")
            else:
                st.write("Response:")
                st.success(result["messages"][-1])
        else:
            st.warning("Please enter a valid input.")
|
|
|
|
|
# Sidebar: link to the notebook this demo is based on.
st.sidebar.title("References")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
|
|
|
|