DrishtiSharma committed (verified)
Commit 0a642a3 · Parent(s): a5c39c6

Update app.py

Files changed (1)
  1. app.py +22 -27
app.py CHANGED
@@ -6,7 +6,7 @@ import tempfile
 from typing import TypedDict, Annotated, Sequence
 from langchain_openai import ChatOpenAI
 from langchain_core.tools import tool
-from langchain_core.messages import HumanMessage, AIMessage
+from langchain_core.messages import HumanMessage, ToolMessage
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langgraph.graph import StateGraph, END

@@ -36,19 +36,10 @@ def invoke_model(state):
     """
     Invoke the model and handle tool invocation logic.
     """
-    # Extract the question as a string
-    question = state['messages'][-1].content if isinstance(state['messages'][-1], HumanMessage) else state['messages'][-1]
+    question = state['messages'][-1].content
     response = model_with_tools.invoke(question)

-    # If the response is plain text (no tool calls)
-    if isinstance(response, str):
-        return {"messages": [AIMessage(content=response)]}
-
-    # If no tool calls exist
-    if not response.additional_kwargs.get("tool_calls", []):
-        return {"messages": [AIMessage(content=response.content)]}
-
-    # If tool calls are present, return the full response
+    # Return the model's response with tool_calls, if any
     return {"messages": [response]}

 graph.add_node("agent", invoke_model)
@@ -60,30 +51,33 @@ def invoke_tool(state):
     """
     tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
     for tool_call in tool_calls:
-        if tool_call.get("function", {}).get("name") == "multiply":
-            arguments = json.loads(tool_call.get("function").get("arguments"))
+        if tool_call["function"]["name"] == "multiply":
+            arguments = json.loads(tool_call["function"]["arguments"])
             result = multiply.invoke(arguments)
-            return {"messages": [AIMessage(content=f"Tool Result: {result}")]}
-    return {"messages": [AIMessage(content="No valid tool input provided.")]}
+            return {"messages": [ToolMessage(content=str(result))]}
+    return {"messages": [ToolMessage(content="No valid tool input provided.")]}

 graph.add_node("tool", invoke_tool)
-graph.add_edge("tool", END)
-graph.set_entry_point("agent")

-# Router Logic
+# Router Node: Manual Addition
 def router(state):
-    """
-    Decide whether to invoke a tool or return the response.
-    """
     tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
-    return "tool" if tool_calls else END
+    return {"messages": [ToolMessage(content="Routing...")]}  # Dummy message for router node

-graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
+graph.add_node("router", router)
+
+# Add explicit edges from agent to router, and router to tool/END
+graph.add_edge("agent", "router")
+graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
+graph.add_edge("tool", END)
+
+# Compile the graph
+graph.set_entry_point("agent")
 app_graph = graph.compile()

-# Save graph visualization as an image
+# Save graph visualization with router explicitly included
 with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
-    graph_viz = app_graph.get_graph(xray=True)
+    graph_viz = app_graph.get_graph(xray=True)  # Ensures detailed visualization
     tmpfile.write(graph_viz.draw_mermaid_png())
     graph_image_path = tmpfile.name

@@ -93,6 +87,7 @@ st.title("Simple Tool Calling Demo")
 # Display the workflow graph
 st.image(graph_image_path, caption="Workflow Visualization")

+# Tabbed Interface
 tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

 # Multiplication Tool Tab
@@ -121,7 +116,6 @@ with tab2:
     if st.button("Submit"):
         if user_input:
             try:
-                # Pass the user input as a HumanMessage
                 result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
                 st.write("Response:")
                 st.success(result['messages'][-1].content)
@@ -133,3 +127,4 @@ with tab2:
 # Sidebar for References
 st.sidebar.title("References")
 st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
+
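
For context, here is a minimal, self-contained sketch of the tool-calling loop that app.py builds with LangGraph. It is not code from this commit: the AgentState schema, the multiply tool body, the model name, and the bind_tools() call are assumptions reconstructed from the imports and calls visible in the diff, and the conditional edge is attached directly to the agent node instead of going through the separate router node added above. Note also that langchain_core's ToolMessage requires a tool_call_id, so the sketch passes along the id from the model's tool call.

import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END


class AgentState(TypedDict):
    # Messages returned by each node are appended to the running list
    messages: Annotated[Sequence[BaseMessage], operator.add]


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


# Assumed model setup; requires OPENAI_API_KEY in the environment
model_with_tools = ChatOpenAI(model="gpt-4o-mini").bind_tools([multiply])


def invoke_model(state):
    # The AIMessage returned here carries tool_calls when the model wants the tool
    response = model_with_tools.invoke(state["messages"][-1].content)
    return {"messages": [response]}


def invoke_tool(state):
    for tool_call in state["messages"][-1].additional_kwargs.get("tool_calls", []):
        if tool_call["function"]["name"] == "multiply":
            arguments = json.loads(tool_call["function"]["arguments"])
            result = multiply.invoke(arguments)
            # ToolMessage needs the id of the tool call it answers
            return {"messages": [ToolMessage(content=str(result), tool_call_id=tool_call["id"])]}
    return {"messages": []}  # no matching tool call; leave the state as-is


graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
graph.set_entry_point("agent")
graph.add_conditional_edges(
    "agent",
    lambda state: "tool" if state["messages"][-1].additional_kwargs.get("tool_calls") else END,
    {"tool": "tool", END: END},
)
graph.add_edge("tool", END)
app_graph = graph.compile()

print(app_graph.invoke({"messages": [HumanMessage(content="What is 6 times 7?")]})["messages"][-1].content)

A multiplication question should end at the tool node and print the product; a general question skips the tool and prints the model's reply directly.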