DrishtiSharma commited on
Commit
a5c39c6
·
verified ·
1 Parent(s): beab746

Update interim.py

Browse files
Files changed (1) hide show
  1. interim.py +50 -16
interim.py CHANGED
@@ -6,6 +6,7 @@ import tempfile
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
 
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
  from langgraph.graph import StateGraph, END
11
 
@@ -15,11 +16,13 @@ os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
15
  # Model Initialization
16
  model = ChatOpenAI(temperature=0)
17
 
 
18
  @tool
19
  def multiply(first_number: int, second_number: int):
20
  """Multiplies two numbers together and returns the result."""
21
  return first_number * second_number
22
 
 
23
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
24
 
25
  # State Setup
@@ -28,35 +31,60 @@ class AgentState(TypedDict):
28
 
29
  graph = StateGraph(AgentState)
30
 
 
31
  def invoke_model(state):
32
- question = state['messages'][-1]
33
- return {"messages": [model_with_tools.invoke(question)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  graph.add_node("agent", invoke_model)
36
 
 
37
  def invoke_tool(state):
 
 
 
38
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
39
  for tool_call in tool_calls:
40
- if tool_call.get("function").get("name") == "multiply":
41
- res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
42
- return {"messages": [f"Tool Result: {res}"]}
43
- return {"messages": ["No tool input provided."]}
 
44
 
45
  graph.add_node("tool", invoke_tool)
46
  graph.add_edge("tool", END)
47
  graph.set_entry_point("agent")
48
 
 
49
  def router(state):
50
- calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
51
- return "multiply" if calls else "end"
 
 
 
52
 
53
- graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
54
  app_graph = graph.compile()
55
 
56
  # Save graph visualization as an image
57
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
58
  graph_viz = app_graph.get_graph(xray=True)
59
- tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
60
  graph_image_path = tmpfile.name
61
 
62
  # Streamlit Interface
@@ -67,6 +95,7 @@ st.image(graph_image_path, caption="Workflow Visualization")
67
 
68
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
69
 
 
70
  with tab1:
71
  st.subheader("Try Multiplication")
72
  col1, col2 = st.columns(2)
@@ -78,9 +107,13 @@ with tab1:
78
 
79
  if st.button("Multiply"):
80
  question = f"What is {first_number} * {second_number}?"
81
- output = app_graph.invoke({"messages": [question]})
82
- st.success(output['messages'][-1])
 
 
 
83
 
 
84
  with tab2:
85
  st.subheader("General Query")
86
  user_input = st.text_input("Enter your question here")
@@ -88,15 +121,16 @@ with tab2:
88
  if st.button("Submit"):
89
  if user_input:
90
  try:
91
- result = app_graph.invoke({"messages": [user_input]})
 
92
  st.write("Response:")
93
- st.success(result['messages'][-1])
94
  except Exception as e:
95
- st.error("Something went wrong. Try again!")
96
  else:
97
  st.warning("Please enter a valid input.")
98
 
99
-
100
  st.sidebar.title("References")
101
  st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
102
 
 
6
  from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
+ from langchain_core.messages import HumanMessage, AIMessage
10
  from langchain_core.utils.function_calling import convert_to_openai_tool
11
  from langgraph.graph import StateGraph, END
12
 
 
16
  # Model Initialization
17
  model = ChatOpenAI(temperature=0)
18
 
19
+ # Define the tool
20
  @tool
21
  def multiply(first_number: int, second_number: int):
22
  """Multiplies two numbers together and returns the result."""
23
  return first_number * second_number
24
 
25
+ # Bind tool to model
26
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
27
 
28
  # State Setup
 
31
 
32
  graph = StateGraph(AgentState)
33
 
34
+ # Model Invocation
35
  def invoke_model(state):
36
+ """
37
+ Invoke the model and handle tool invocation logic.
38
+ """
39
+ # Extract the question as a string
40
+ question = state['messages'][-1].content if isinstance(state['messages'][-1], HumanMessage) else state['messages'][-1]
41
+ response = model_with_tools.invoke(question)
42
+
43
+ # If the response is plain text (no tool calls)
44
+ if isinstance(response, str):
45
+ return {"messages": [AIMessage(content=response)]}
46
+
47
+ # If no tool calls exist
48
+ if not response.additional_kwargs.get("tool_calls", []):
49
+ return {"messages": [AIMessage(content=response.content)]}
50
+
51
+ # If tool calls are present, return the full response
52
+ return {"messages": [response]}
53
 
54
  graph.add_node("agent", invoke_model)
55
 
56
+ # Tool Invocation
57
  def invoke_tool(state):
58
+ """
59
+ Invoke the 'multiply' tool if it's called by the model.
60
+ """
61
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
62
  for tool_call in tool_calls:
63
+ if tool_call.get("function", {}).get("name") == "multiply":
64
+ arguments = json.loads(tool_call.get("function").get("arguments"))
65
+ result = multiply.invoke(arguments)
66
+ return {"messages": [AIMessage(content=f"Tool Result: {result}")]}
67
+ return {"messages": [AIMessage(content="No valid tool input provided.")]}
68
 
69
  graph.add_node("tool", invoke_tool)
70
  graph.add_edge("tool", END)
71
  graph.set_entry_point("agent")
72
 
73
+ # Router Logic
74
  def router(state):
75
+ """
76
+ Decide whether to invoke a tool or return the response.
77
+ """
78
+ tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
79
+ return "tool" if tool_calls else END
80
 
81
+ graph.add_conditional_edges("agent", router, {"tool": "tool", END: END})
82
  app_graph = graph.compile()
83
 
84
  # Save graph visualization as an image
85
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
86
  graph_viz = app_graph.get_graph(xray=True)
87
+ tmpfile.write(graph_viz.draw_mermaid_png())
88
  graph_image_path = tmpfile.name
89
 
90
  # Streamlit Interface
 
95
 
96
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
97
 
98
+ # Multiplication Tool Tab
99
  with tab1:
100
  st.subheader("Try Multiplication")
101
  col1, col2 = st.columns(2)
 
107
 
108
  if st.button("Multiply"):
109
  question = f"What is {first_number} * {second_number}?"
110
+ try:
111
+ output = app_graph.invoke({"messages": [HumanMessage(content=question)]})
112
+ st.success(output['messages'][-1].content)
113
+ except Exception as e:
114
+ st.error(f"Error: {e}")
115
 
116
+ # General Query Tab
117
  with tab2:
118
  st.subheader("General Query")
119
  user_input = st.text_input("Enter your question here")
 
121
  if st.button("Submit"):
122
  if user_input:
123
  try:
124
+ # Pass the user input as a HumanMessage
125
+ result = app_graph.invoke({"messages": [HumanMessage(content=user_input)]})
126
  st.write("Response:")
127
+ st.success(result['messages'][-1].content)
128
  except Exception as e:
129
+ st.error(f"Error: {e}")
130
  else:
131
  st.warning("Please enter a valid input.")
132
 
133
+ # Sidebar for References
134
  st.sidebar.title("References")
135
  st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
136