DrishtiSharma commited on
Commit
beab746
·
verified ·
1 Parent(s): 20d5bf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -20
app.py CHANGED
@@ -6,6 +6,7 @@ import tempfile
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
 
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
  from langgraph.graph import StateGraph, END
11
 
@@ -15,11 +16,13 @@ os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
15
  # Model Initialization
16
  model = ChatOpenAI(temperature=0)
17
 
 
18
  @tool
19
  def multiply(first_number: int, second_number: int):
20
  """Multiplies two numbers together and returns the result."""
21
  return first_number * second_number
22
 
 
23
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
24
 
25
  # State Setup
@@ -28,41 +31,60 @@ class AgentState(TypedDict):
28
 
29
  graph = StateGraph(AgentState)
30
 
 
31
  def invoke_model(state):
32
- question = state['messages'][-1]
 
 
 
 
33
  response = model_with_tools.invoke(question)
34
- # If no tool calls are found, return the raw response content
 
 
 
 
 
35
  if not response.additional_kwargs.get("tool_calls", []):
36
- return {"messages": [response.content]}
37
- # Otherwise, return the response object as before
38
- return {"messages": [response]}
39
 
 
 
40
 
41
  graph.add_node("agent", invoke_model)
42
 
 
43
  def invoke_tool(state):
 
 
 
44
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
45
  for tool_call in tool_calls:
46
- if tool_call.get("function").get("name") == "multiply":
47
- res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
48
- return {"messages": [f"Tool Result: {res}"]}
49
- return {"messages": ["No tool input provided."]}
 
50
 
51
  graph.add_node("tool", invoke_tool)
52
  graph.add_edge("tool", END)
53
  graph.set_entry_point("agent")
54
 
 
55
  def router(state):
56
- calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
57
- return "multiply" if calls else "end"
 
 
 
58
 
59
- graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
60
  app_graph = graph.compile()
61
 
62
  # Save graph visualization as an image
63
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
64
  graph_viz = app_graph.get_graph(xray=True)
65
- tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
66
  graph_image_path = tmpfile.name
67
 
68
  # Streamlit Interface
@@ -73,6 +95,7 @@ st.image(graph_image_path, caption="Workflow Visualization")
73
 
74
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
75
 
 
76
  with tab1:
77
  st.subheader("Try Multiplication")
78
  col1, col2 = st.columns(2)
@@ -84,9 +107,13 @@ with tab1:
84
 
85
  if st.button("Multiply"):
86
  question = f"What is {first_number} * {second_number}?"
87
- output = app_graph.invoke({"messages": [question]})
88
- st.success(output['messages'][-1])
 
 
 
89
 
 
90
  with tab2:
91
  st.subheader("General Query")
92
  user_input = st.text_input("Enter your question here")
@@ -94,14 +121,15 @@ with tab2:
94
  if st.button("Submit"):
95
  if user_input:
96
  try:
97
- result = app_graph.invoke({"messages": [user_input]})
 
98
  st.write("Response:")
99
- st.success(result['messages'][-1])
100
  except Exception as e:
101
- st.error("Something went wrong. Try again!")
102
  else:
103
  st.warning("Please enter a valid input.")
104
 
105
-
106
  st.sidebar.title("References")
107
- st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
 
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
  from langchain_core.utils.function_calling import convert_to_openai_tool
11
  from langgraph.graph import StateGraph, END
12
 
 
16
  # Model Initialization
17
  model = ChatOpenAI(temperature=0)
18
 
19
+ # Define the tool
20
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together and returns the result.

    Note: this docstring is also the tool description surfaced to the
    model via `convert_to_openai_tool`, so keep it accurate.

    Args:
        first_number: The first integer factor.
        second_number: The second integer factor.

    Returns:
        The product of the two numbers.
    """
    return first_number * second_number
24
 
25
+ # Bind tool to model
26
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
27
 
28
  # State Setup
 
31
 
32
  graph = StateGraph(AgentState)
33
 
34
# Model Invocation
def invoke_model(state):
    """
    Call the tool-bound model on the latest message and normalize the output.

    Returns a state update whose 'messages' entry is either a plain
    AIMessage (no tool call requested) or the raw model response
    (tool calls present, to be consumed by the tool node).
    """
    latest = state['messages'][-1]
    # Pull plain text out of a HumanMessage; otherwise pass the entry through as-is.
    prompt = latest.content if isinstance(latest, HumanMessage) else latest
    reply = model_with_tools.invoke(prompt)

    # Defensive: if the invocation ever yields a bare string, wrap it.
    if isinstance(reply, str):
        return {"messages": [AIMessage(content=reply)]}

    # No tool calls requested — surface only the textual content.
    if not reply.additional_kwargs.get("tool_calls", []):
        return {"messages": [AIMessage(content=reply.content)]}

    # Tool calls present — forward the full response for the tool node.
    return {"messages": [reply]}
53
 
54
  graph.add_node("agent", invoke_model)
55
 
56
# Tool Invocation
def invoke_tool(state):
    """
    Execute the 'multiply' tool for the first matching tool call, if any.

    Scans the latest message's tool calls; on the first call named
    'multiply', parses its JSON arguments, runs the tool, and returns
    the result as an AIMessage. Falls through to an error message when
    no matching call is found.
    """
    calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    for call in calls:
        fn = call.get("function", {})
        if fn.get("name") != "multiply":
            continue
        # Arguments arrive as a JSON-encoded string from the model.
        args = json.loads(fn.get("arguments"))
        product = multiply.invoke(args)
        return {"messages": [AIMessage(content=f"Tool Result: {product}")]}
    return {"messages": [AIMessage(content="No valid tool input provided.")]}
68
 
69
  graph.add_node("tool", invoke_tool)
70
  graph.add_edge("tool", END)
71
  graph.set_entry_point("agent")
72
 
73
# Router Logic
def router(state):
    """
    Decide the next node: run the tool when the model requested a
    tool call, otherwise end the graph.
    """
    pending = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if pending:
        return "tool"
    return END
80
 
81
+ graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
82
  app_graph = graph.compile()
83
 
84
# Render the compiled workflow to a temporary PNG for display in Streamlit.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)
    # draw_mermaid_png() returns raw PNG bytes; write them straight to the file.
    tmpfile.write(graph_viz.draw_mermaid_png())
    # delete=False keeps the file on disk so st.image can read it later.
    graph_image_path = tmpfile.name
89
 
90
  # Streamlit Interface
 
95
 
96
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
97
 
98
+ # Multiplication Tool Tab
99
  with tab1:
100
  st.subheader("Try Multiplication")
101
  col1, col2 = st.columns(2)
 
107
 
108
  if st.button("Multiply"):
109
  question = f"What is {first_number} * {second_number}?"
110
+ try:
111
+ output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
112
+ st.success(output['messages'][-1].content)
113
+ except Exception as e:
114
+ st.error(f"Error: {e}")
115
 
116
+ # General Query Tab
117
  with tab2:
118
  st.subheader("General Query")
119
  user_input = st.text_input("Enter your question here")
 
121
  if st.button("Submit"):
122
  if user_input:
123
  try:
124
+ # Pass the user input as a HumanMessage
125
+ result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
126
  st.write("Response:")
127
+ st.success(result['messages'][-1].content)
128
  except Exception as e:
129
+ st.error(f"Error: {e}")
130
  else:
131
  st.warning("Please enter a valid input.")
132
 
133
+ # Sidebar for References
134
  st.sidebar.title("References")
135
+ st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")