DrishtiSharma commited on
Commit
5bd2e4f
·
verified ·
1 Parent(s): bced22e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -35
app.py CHANGED
@@ -33,55 +33,57 @@ graph = StateGraph(AgentState)
33
 
34
  # Model Invocation
35
  def invoke_model(state):
36
- question = state['messages'][-1].content
 
 
 
 
37
  response = model_with_tools.invoke(question)
 
 
 
 
 
 
 
 
 
 
38
  return {"messages": [response]}
39
 
40
  graph.add_node("agent", invoke_model)
41
 
42
  # Tool Invocation
43
  def invoke_tool(state):
 
 
 
44
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
45
- tool_results = []
46
-
47
  for tool_call in tool_calls:
48
- if tool_call["function"]["name"] == "multiply":
49
- arguments = json.loads(tool_call["function"]["arguments"])
50
  result = multiply.invoke(arguments)
51
- tool_results.append(
52
- AIMessage(
53
- content=f"The result of multiplying {arguments['first_number']} and {arguments['second_number']} is {result}.",
54
- additional_kwargs={"tool_call_id": tool_call["id"]}
55
- )
56
- )
57
- return {"messages": tool_results or [AIMessage(content="No valid tool input provided.")]}
58
 
59
 
60
  graph.add_node("tool", invoke_tool)
 
 
61
 
62
- # Explicit Router Node
63
  def router(state):
64
  """
65
- Route to tool if tool calls exist; otherwise END the workflow.
66
  """
67
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
68
- if tool_calls:
69
- return {"messages": [AIMessage(content="Routing to tool...")]}
70
- else:
71
- return {"messages": [AIMessage(content=state['messages'][-1].content)]}
72
 
73
- graph.add_node("router", router)
74
-
75
- # Add explicit edges
76
- graph.add_edge("agent", "router")
77
- graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
78
- graph.add_edge("tool", END)
79
-
80
- # Compile the graph
81
- graph.set_entry_point("agent")
82
  app_graph = graph.compile()
83
 
84
- # Save graph visualization with xray for visibility
85
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
86
  graph_viz = app_graph.get_graph(xray=True)
87
  tmpfile.write(graph_viz.draw_mermaid_png())
@@ -89,14 +91,21 @@ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
89
 
90
  # Streamlit Interface
91
  st.title("Simple Tool Calling Demo")
 
 
92
  st.image(graph_image_path, caption="Workflow Visualization")
93
 
94
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
95
 
 
96
  with tab1:
97
  st.subheader("Try Multiplication")
98
- first_number = st.number_input("First Number", value=0, step=1)
99
- second_number = st.number_input("Second Number", value=0, step=1)
 
 
 
 
100
 
101
  if st.button("Multiply"):
102
  question = f"What is {first_number} * {second_number}?"
@@ -106,16 +115,22 @@ with tab1:
106
  except Exception as e:
107
  st.error(f"Error: {e}")
108
 
 
109
  with tab2:
110
  st.subheader("General Query")
111
  user_input = st.text_input("Enter your question here")
112
 
113
  if st.button("Submit"):
114
- try:
115
- result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
116
- st.success(result['messages'][-1].content)
117
- except Exception as e:
118
- st.error(f"Error: {e}")
 
 
 
 
 
119
 
120
  # Sidebar for References
121
  st.sidebar.title("References")
 
33
 
34
  # Model Invocation
35
  def invoke_model(state):
36
+ """
37
+ Invoke the model and handle tool invocation logic.
38
+ """
39
+ # Extract the question as a string
40
+ question = state['messages'][-1].content if isinstance(state['messages'][-1], HumanMessage) else state['messages'][-1]
41
  response = model_with_tools.invoke(question)
42
+
43
+ # If the response is plain text (no tool calls)
44
+ if isinstance(response, str):
45
+ return {"messages": [AIMessage(content=response)]}
46
+
47
+ # If no tool calls exist
48
+ if not response.additional_kwargs.get("tool_calls", []):
49
+ return {"messages": [AIMessage(content=response.content)]}
50
+
51
+ # If tool calls are present, return the full response
52
  return {"messages": [response]}
53
 
54
  graph.add_node("agent", invoke_model)
55
 
56
  # Tool Invocation
57
  def invoke_tool(state):
58
+ """
59
+ Invoke the 'multiply' tool if it's called by the model.
60
+ """
61
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
 
 
62
  for tool_call in tool_calls:
63
+ if tool_call.get("function", {}).get("name") == "multiply":
64
+ arguments = json.loads(tool_call.get("function").get("arguments"))
65
  result = multiply.invoke(arguments)
66
+ # Append the result back to the state
67
+ return {"messages": state['messages'] + [AIMessage(content=f"Tool Result: {result}")]}
68
+ return {"messages": state['messages'] + [AIMessage(content="No valid tool input provided.")]}
 
 
 
 
69
 
70
 
71
  graph.add_node("tool", invoke_tool)
72
+ graph.add_edge("tool", END)
73
+ graph.set_entry_point("agent")
74
 
75
+ # Router Logic
76
  def router(state):
77
  """
78
+ Decide whether to invoke a tool or return the response.
79
  """
80
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
81
+ return "tool" if tool_calls else END
 
 
 
82
 
83
+ graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
 
 
 
 
 
 
 
 
84
  app_graph = graph.compile()
85
 
86
+ # Save graph visualization as an image
87
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
88
  graph_viz = app_graph.get_graph(xray=True)
89
  tmpfile.write(graph_viz.draw_mermaid_png())
 
91
 
92
  # Streamlit Interface
93
  st.title("Simple Tool Calling Demo")
94
+
95
+ # Display the workflow graph
96
  st.image(graph_image_path, caption="Workflow Visualization")
97
 
98
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
99
 
100
+ # Multiplication Tool Tab
101
  with tab1:
102
  st.subheader("Try Multiplication")
103
+ col1, col2 = st.columns(2)
104
+
105
+ with col1:
106
+ first_number = st.number_input("First Number", value=0, step=1)
107
+ with col2:
108
+ second_number = st.number_input("Second Number", value=0, step=1)
109
 
110
  if st.button("Multiply"):
111
  question = f"What is {first_number} * {second_number}?"
 
115
  except Exception as e:
116
  st.error(f"Error: {e}")
117
 
118
+ # General Query Tab
119
  with tab2:
120
  st.subheader("General Query")
121
  user_input = st.text_input("Enter your question here")
122
 
123
  if st.button("Submit"):
124
+ if user_input:
125
+ try:
126
+ # Pass the user input as a HumanMessage
127
+ result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
128
+ st.write("Response:")
129
+ st.success(result['messages'][-1].content)
130
+ except Exception as e:
131
+ st.error(f"Error: {e}")
132
+ else:
133
+ st.warning("Please enter a valid input.")
134
 
135
  # Sidebar for References
136
  st.sidebar.title("References")