DrishtiSharma committed on
Commit
56ad039
·
verified ·
1 Parent(s): dd74f8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -47
app.py CHANGED
@@ -7,16 +7,16 @@ from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
- from langgraph.graph import StateGraph, END
11
- from langchain_core.messages import BaseMessage
12
 
13
-
14
- # Environment Setup
15
  os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
16
 
17
- # Model Initialization
18
  model = ChatOpenAI(temperature=0)
19
 
 
20
  @tool
21
  def multiply(first_number: int, second_number: int):
22
  """Multiplies two numbers together and returns the result."""
@@ -24,48 +24,43 @@ def multiply(first_number: int, second_number: int):
24
 
25
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
26
 
27
- # State Setup
28
- class AgentState(TypedDict):
29
- messages: Annotated[Sequence, operator.add]
30
-
31
- graph = StateGraph(AgentState)
32
-
33
- # Define MessagesState
34
  class MessagesState(TypedDict):
35
  messages: Annotated[Sequence[BaseMessage], operator.add]
36
 
37
- def invoke_model(state):
38
- question = state['messages'][-1]
39
- return {"messages": [model_with_tools.invoke(question)]}
 
 
 
40
 
41
- graph.add_node("agent", invoke_model)
42
-
43
- def invoke_tool(state):
44
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
45
- for tool_call in tool_calls:
46
- if tool_call.get("function").get("name") == "multiply":
47
- res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
48
- return {"messages": [f"Tool Result: {res}"]}
49
- return {"messages": ["No tool input provided."]}
50
 
51
- graph.add_node("tool", invoke_tool)
52
- graph.add_edge("tool", END)
53
- graph.set_entry_point("agent")
54
-
55
- # Define the router function
56
- def router(state):
 
 
 
 
 
57
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
58
- if len(tool_calls):
59
- return "tools"
60
- else:
61
- return "end"
62
 
63
- # Add graph nodes and edges
64
  app_graph = StateGraph(MessagesState)
65
  app_graph.add_node("assistant", assistant)
66
- app_graph.add_node("tools", ToolNode(tools))
67
 
68
- # Add conditional routing logic
69
  app_graph.add_edge(START, "assistant")
70
  app_graph.add_conditional_edges("assistant", router, {
71
  "tools": "tools",
@@ -76,18 +71,19 @@ app_graph.add_edge("tools", "assistant")
76
  # Compile the graph
77
  react_graph = app_graph.compile()
78
 
79
- # Save graph visualization as an image
80
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
81
  graph_viz = app_graph.get_graph(xray=True)
82
  tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
83
  graph_image_path = tmpfile.name
84
 
85
- # Streamlit Interface
86
  st.title("Simple Tool Calling Demo")
87
 
88
  # Display the workflow graph
89
  st.image(graph_image_path, caption="Workflow Visualization")
90
 
 
91
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
92
 
93
  with tab1:
@@ -100,25 +96,32 @@ with tab1:
100
  second_number = st.number_input("Second Number", value=0, step=1)
101
 
102
  if st.button("Multiply"):
103
- question = f"What is {first_number} * {second_number}?"
104
- output = app_graph.invoke({"messages": [question]})
105
- st.success(output['messages'][-1])
106
-
 
 
 
 
107
  with tab2:
108
  st.subheader("General Query")
109
  user_input = st.text_input("Enter your question here")
110
 
111
  if st.button("Submit"):
112
  if user_input:
 
113
  try:
114
- result = app_graph.invoke({"messages": [user_input]})
115
  st.write("Response:")
116
- st.success(result['messages'][-1])
117
  except Exception as e:
118
- st.error("Something went wrong. Try again!")
119
  else:
120
  st.warning("Please enter a valid input.")
121
 
122
-
123
  st.sidebar.title("Reference")
124
- st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
 
 
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
+ from langgraph.graph import StateGraph, END, START
11
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
12
 
13
+ # ------------------- Environment Setup -------------------
 
14
  os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
15
 
16
+ # ------------------- Model Initialization -------------------
17
  model = ChatOpenAI(temperature=0)
18
 
19
+ # ------------------- Define the Tool -------------------
20
  @tool
21
  def multiply(first_number: int, second_number: int):
22
  """Multiplies two numbers together and returns the result."""
 
24
 
25
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
26
 
27
+ # ------------------- State Setup -------------------
 
 
 
 
 
 
28
  class MessagesState(TypedDict):
29
  messages: Annotated[Sequence[BaseMessage], operator.add]
30
 
31
+ # ------------------- Assistant Node -------------------
32
+ def assistant(state: MessagesState):
33
+ """Invoke the model to process messages."""
34
+ messages = state['messages']
35
+ response = model_with_tools.invoke(messages[-1])
36
+ return {"messages": messages + [response]}
37
 
38
+ # ------------------- Tools Node -------------------
39
+ def tools(state: MessagesState):
40
+ """Invoke tools based on tool calls."""
41
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
42
+ responses = []
 
 
 
 
43
 
44
+ for tool_call in tool_calls:
45
+ if tool_call["function"]["name"] == "multiply":
46
+ args = json.loads(tool_call["function"]["arguments"])
47
+ result = multiply.invoke(args)
48
+ responses.append(
49
+ AIMessage(content=f"Tool Result: {result}", name="multiply")
50
+ )
51
+ return {"messages": state["messages"] + responses}
52
+
53
+ # ------------------- Router Logic -------------------
54
+ def router(state: MessagesState):
55
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
56
+ return "tools" if tool_calls else "end"
 
 
 
57
 
58
+ # ------------------- Graph Definition -------------------
59
  app_graph = StateGraph(MessagesState)
60
  app_graph.add_node("assistant", assistant)
61
+ app_graph.add_node("tools", tools)
62
 
63
+ # Define edges and conditional routing
64
  app_graph.add_edge(START, "assistant")
65
  app_graph.add_conditional_edges("assistant", router, {
66
  "tools": "tools",
 
71
  # Compile the graph
72
  react_graph = app_graph.compile()
73
 
74
+ # ------------------- Save Graph Visualization -------------------
75
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
76
  graph_viz = app_graph.get_graph(xray=True)
77
  tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
78
  graph_image_path = tmpfile.name
79
 
80
+ # ------------------- Streamlit Interface -------------------
81
  st.title("Simple Tool Calling Demo")
82
 
83
  # Display the workflow graph
84
  st.image(graph_image_path, caption="Workflow Visualization")
85
 
86
+ # ------------------- Tab 1: Multiplication -------------------
87
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
88
 
89
  with tab1:
 
96
  second_number = st.number_input("Second Number", value=0, step=1)
97
 
98
  if st.button("Multiply"):
99
+ question = HumanMessage(content=f"What is {first_number} * {second_number}?")
100
+ try:
101
+ output = react_graph.invoke({"messages": [question]})
102
+ st.success(output['messages'][-1].content)
103
+ except Exception as e:
104
+ st.error(f"Error: {e}")
105
+
106
+ # ------------------- Tab 2: General Queries -------------------
107
  with tab2:
108
  st.subheader("General Query")
109
  user_input = st.text_input("Enter your question here")
110
 
111
  if st.button("Submit"):
112
  if user_input:
113
+ question = HumanMessage(content=user_input)
114
  try:
115
+ result = react_graph.invoke({"messages": [question]})
116
  st.write("Response:")
117
+ st.success(result['messages'][-1].content)
118
  except Exception as e:
119
+ st.error(f"Error: {e}")
120
  else:
121
  st.warning("Please enter a valid input.")
122
 
123
+ # ------------------- Sidebar Reference -------------------
124
  st.sidebar.title("Reference")
125
+ st.sidebar.markdown(
126
+ "1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)"
127
+ )