DrishtiSharma commited on
Commit
963e565
·
verified ·
1 Parent(s): 141b6b8

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +114 -0
test.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import operator
4
+ import streamlit as st
5
+ import tempfile
6
+ from typing import TypedDict, Annotated, Sequence
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.tools import tool
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
+ from langchain_core.utils.function_calling import convert_to_openai_tool
11
+ from langgraph.graph import StateGraph, END
12
+
13
+ # Environment Setup
14
+ os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
15
+
16
+ # Model Initialization
17
+ model = ChatOpenAI(temperature=0)
18
+
19
+ # Define the tool
20
+ @tool
21
+ def multiply(first_number: int, second_number: int):
22
+ """Multiplies two numbers together and returns the result."""
23
+ return first_number * second_number
24
+
25
+ # Bind tool to model
26
+ model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
27
+
28
+ # State Setup
29
+ class AgentState(TypedDict):
30
+ messages: Annotated[Sequence, operator.add]
31
+
32
+ graph = StateGraph(AgentState)
33
+
34
+ # Model Invocation
35
+ def invoke_model(state):
36
+ question = state['messages'][-1].content
37
+ response = model_with_tools.invoke(question)
38
+ return {"messages": [response]}
39
+
40
+ graph.add_node("agent", invoke_model)
41
+
42
+ # Tool Invocation
43
+ def invoke_tool(state):
44
+ tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
45
+ tool_results = []
46
+
47
+ for tool_call in tool_calls:
48
+ if tool_call["function"]["name"] == "multiply":
49
+ arguments = json.loads(tool_call["function"]["arguments"])
50
+ result = multiply.invoke(arguments)
51
+ tool_results.append(
52
+ AIMessage(content=f"Tool Result: {result}", additional_kwargs={"tool_call_id": tool_call["id"]})
53
+ )
54
+ return {"messages": tool_results or [AIMessage(content="No valid tool input provided.")]}
55
+
56
+ graph.add_node("tool", invoke_tool)
57
+
58
+ # Explicit Router Node
59
+ def router(state):
60
+ """
61
+ Route to tool if tool calls exist; otherwise END the workflow.
62
+ """
63
+ tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
64
+ if tool_calls:
65
+ return {"messages": [AIMessage(content="Routing to tool...")]}
66
+ else:
67
+ return {"messages": [AIMessage(content=state['messages'][-1].content)]}
68
+
69
+ graph.add_node("router", router)
70
+
71
+ # Add explicit edges
72
+ graph.add_edge("agent", "router")
73
+ graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
74
+ graph.add_edge("tool", END)
75
+
76
+ # Compile the graph
77
+ graph.set_entry_point("agent")
78
+ app_graph = graph.compile()
79
+
80
+ # Save graph visualization with xray for visibility
81
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
82
+ graph_viz = app_graph.get_graph(xray=True)
83
+ tmpfile.write(graph_viz.draw_mermaid_png())
84
+ graph_image_path = tmpfile.name
85
+
86
+ # Streamlit Interface
87
+ st.title("Simple Tool Calling Demo")
88
+ st.image(graph_image_path, caption="Workflow Visualization")
89
+
90
+ tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
91
+
92
+ with tab1:
93
+ st.subheader("Try Multiplication")
94
+ first_number = st.number_input("First Number", value=0, step=1)
95
+ second_number = st.number_input("Second Number", value=0, step=1)
96
+
97
+ if st.button("Multiply"):
98
+ question = f"What is {first_number} * {second_number}?"
99
+ try:
100
+ output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
101
+ st.success(output['messages'][-1].content)
102
+ except Exception as e:
103
+ st.error(f"Error: {e}")
104
+
105
+ with tab2:
106
+ st.subheader("General Query")
107
+ user_input = st.text_input("Enter your question here")
108
+
109
+ if st.button("Submit"):
110
+ try:
111
+ result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
112
+ st.success(result['messages'][-1].content)
113
+ except Exception as e:
114
+ st.error(f"Error: {e}")