DrishtiSharma commited on
Commit
569f7b7
·
verified ·
1 Parent(s): 963e565

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -41
app.py CHANGED
@@ -6,7 +6,7 @@ import tempfile
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
- from langchain_core.messages import HumanMessage, ToolMessage
10
  from langchain_core.utils.function_calling import convert_to_openai_tool
11
  from langgraph.graph import StateGraph, END
12
 
@@ -33,43 +33,42 @@ graph = StateGraph(AgentState)
33
 
34
  # Model Invocation
35
  def invoke_model(state):
36
- """
37
- Invoke the model and handle tool invocation logic.
38
- """
39
  question = state['messages'][-1].content
40
  response = model_with_tools.invoke(question)
41
-
42
- # Return the model's response with tool_calls, if any
43
  return {"messages": [response]}
44
 
45
  graph.add_node("agent", invoke_model)
46
 
47
  # Tool Invocation
48
  def invoke_tool(state):
49
- """
50
- Invoke the 'multiply' tool if it's called by the model.
51
- """
52
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
 
 
53
  for tool_call in tool_calls:
54
  if tool_call["function"]["name"] == "multiply":
55
- # Extract and parse the arguments
56
  arguments = json.loads(tool_call["function"]["arguments"])
57
  result = multiply.invoke(arguments)
58
- # Return the tool result while maintaining tool_call_id for future use
59
- return {"messages": [ToolMessage(content=str(result), name="multiply", tool_call_id=tool_call["id"])]}
60
- return {"messages": [ToolMessage(content="No valid tool input provided.")]}
61
-
62
 
63
  graph.add_node("tool", invoke_tool)
64
 
65
- # Router Node: Manual Addition
66
  def router(state):
 
 
 
67
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
68
- return {"messages": [ToolMessage(content="Routing...")]} # Dummy message for router node
 
 
 
69
 
70
  graph.add_node("router", router)
71
 
72
- # Add explicit edges from agent to router, and router to tool/END
73
  graph.add_edge("agent", "router")
74
  graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
75
  graph.add_edge("tool", END)
@@ -78,30 +77,22 @@ graph.add_edge("tool", END)
78
  graph.set_entry_point("agent")
79
  app_graph = graph.compile()
80
 
81
- # Save graph visualization with router explicitly included
82
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
83
- graph_viz = app_graph.get_graph(xray=True) # Ensures detailed visualization
84
  tmpfile.write(graph_viz.draw_mermaid_png())
85
  graph_image_path = tmpfile.name
86
 
87
  # Streamlit Interface
88
  st.title("Simple Tool Calling Demo")
89
-
90
- # Display the workflow graph
91
  st.image(graph_image_path, caption="Workflow Visualization")
92
 
93
- # Tabbed Interface
94
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
95
 
96
- # Multiplication Tool Tab
97
  with tab1:
98
  st.subheader("Try Multiplication")
99
- col1, col2 = st.columns(2)
100
-
101
- with col1:
102
- first_number = st.number_input("First Number", value=0, step=1)
103
- with col2:
104
- second_number = st.number_input("Second Number", value=0, step=1)
105
 
106
  if st.button("Multiply"):
107
  question = f"What is {first_number} * {second_number}?"
@@ -111,23 +102,17 @@ with tab1:
111
  except Exception as e:
112
  st.error(f"Error: {e}")
113
 
114
- # General Query Tab
115
  with tab2:
116
  st.subheader("General Query")
117
  user_input = st.text_input("Enter your question here")
118
 
119
  if st.button("Submit"):
120
- if user_input:
121
- try:
122
- result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
123
- st.write("Response:")
124
- st.success(result['messages'][-1].content)
125
- except Exception as e:
126
- st.error(f"Error: {e}")
127
- else:
128
- st.warning("Please enter a valid input.")
129
 
130
  # Sidebar for References
131
  st.sidebar.title("References")
132
- st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
133
-
 
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
  from langchain_core.utils.function_calling import convert_to_openai_tool
11
  from langgraph.graph import StateGraph, END
12
 
 
33
 
34
  # Model Invocation
35
  def invoke_model(state):
 
 
 
36
  question = state['messages'][-1].content
37
  response = model_with_tools.invoke(question)
 
 
38
  return {"messages": [response]}
39
 
40
  graph.add_node("agent", invoke_model)
41
 
42
  # Tool Invocation
43
  def invoke_tool(state):
 
 
 
44
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
45
+ tool_results = []
46
+
47
  for tool_call in tool_calls:
48
  if tool_call["function"]["name"] == "multiply":
 
49
  arguments = json.loads(tool_call["function"]["arguments"])
50
  result = multiply.invoke(arguments)
51
+ tool_results.append(
52
+ AIMessage(content=f"Tool Result: {result}", additional_kwargs={"tool_call_id": tool_call["id"]})
53
+ )
54
+ return {"messages": tool_results or [AIMessage(content="No valid tool input provided.")]}
55
 
56
  graph.add_node("tool", invoke_tool)
57
 
58
+ # Explicit Router Node
59
  def router(state):
60
+ """
61
+ Route to tool if tool calls exist; otherwise END the workflow.
62
+ """
63
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
64
+ if tool_calls:
65
+ return {"messages": [AIMessage(content="Routing to tool...")]}
66
+ else:
67
+ return {"messages": [AIMessage(content=state['messages'][-1].content)]}
68
 
69
  graph.add_node("router", router)
70
 
71
+ # Add explicit edges
72
  graph.add_edge("agent", "router")
73
  graph.add_conditional_edges("router", lambda state: "tool" if state['messages'][-1].additional_kwargs.get("tool_calls") else END, {"tool": "tool", END: END})
74
  graph.add_edge("tool", END)
 
77
  graph.set_entry_point("agent")
78
  app_graph = graph.compile()
79
 
80
+ # Save graph visualization with xray for visibility
81
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
82
+ graph_viz = app_graph.get_graph(xray=True)
83
  tmpfile.write(graph_viz.draw_mermaid_png())
84
  graph_image_path = tmpfile.name
85
 
86
  # Streamlit Interface
87
  st.title("Simple Tool Calling Demo")
 
 
88
  st.image(graph_image_path, caption="Workflow Visualization")
89
 
 
90
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
91
 
 
92
  with tab1:
93
  st.subheader("Try Multiplication")
94
+ first_number = st.number_input("First Number", value=0, step=1)
95
+ second_number = st.number_input("Second Number", value=0, step=1)
 
 
 
 
96
 
97
  if st.button("Multiply"):
98
  question = f"What is {first_number} * {second_number}?"
 
102
  except Exception as e:
103
  st.error(f"Error: {e}")
104
 
 
105
  with tab2:
106
  st.subheader("General Query")
107
  user_input = st.text_input("Enter your question here")
108
 
109
  if st.button("Submit"):
110
+ try:
111
+ result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
112
+ st.success(result['messages'][-1].content)
113
+ except Exception as e:
114
+ st.error(f"Error: {e}")
 
 
 
 
115
 
116
  # Sidebar for References
117
  st.sidebar.title("References")
118
+ st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")