DrishtiSharma commited on
Commit
1d4c700
·
verified ·
1 Parent(s): 56ad039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -64
app.py CHANGED
@@ -7,16 +7,14 @@ from typing import TypedDict, Annotated, Sequence
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
- from langgraph.graph import StateGraph, END, START
11
- from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
12
 
13
- # ------------------- Environment Setup -------------------
14
  os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
15
 
16
- # ------------------- Model Initialization -------------------
17
  model = ChatOpenAI(temperature=0)
18
 
19
- # ------------------- Define the Tool -------------------
20
  @tool
21
  def multiply(first_number: int, second_number: int):
22
  """Multiplies two numbers together and returns the result."""
@@ -24,66 +22,49 @@ def multiply(first_number: int, second_number: int):
24
 
25
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
26
 
27
- # ------------------- State Setup -------------------
28
- class MessagesState(TypedDict):
29
- messages: Annotated[Sequence[BaseMessage], operator.add]
30
 
31
- # ------------------- Assistant Node -------------------
32
- def assistant(state: MessagesState):
33
- """Invoke the model to process messages."""
34
- messages = state['messages']
35
- response = model_with_tools.invoke(messages[-1])
36
- return {"messages": messages + [response]}
37
 
38
- # ------------------- Tools Node -------------------
39
- def tools(state: MessagesState):
40
- """Invoke tools based on tool calls."""
41
- tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
42
- responses = []
43
 
44
- for tool_call in tool_calls:
45
- if tool_call["function"]["name"] == "multiply":
46
- args = json.loads(tool_call["function"]["arguments"])
47
- result = multiply.invoke(args)
48
- responses.append(
49
- AIMessage(content=f"Tool Result: {result}", name="multiply")
50
- )
51
- return {"messages": state["messages"] + responses}
52
-
53
- # ------------------- Router Logic -------------------
54
- def router(state: MessagesState):
55
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
56
- return "tools" if tool_calls else "end"
 
 
 
 
57
 
58
- # ------------------- Graph Definition -------------------
59
- app_graph = StateGraph(MessagesState)
60
- app_graph.add_node("assistant", assistant)
61
- app_graph.add_node("tools", tools)
62
 
63
- # Define edges and conditional routing
64
- app_graph.add_edge(START, "assistant")
65
- app_graph.add_conditional_edges("assistant", router, {
66
- "tools": "tools",
67
- "end": END
68
- })
69
- app_graph.add_edge("tools", "assistant")
70
 
71
- # Compile the graph
72
- react_graph = app_graph.compile()
73
 
74
- # ------------------- Save Graph Visualization -------------------
75
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
76
  graph_viz = app_graph.get_graph(xray=True)
77
  tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
78
  graph_image_path = tmpfile.name
79
 
80
- # ------------------- Streamlit Interface -------------------
81
  st.title("Simple Tool Calling Demo")
82
 
83
  # Display the workflow graph
84
  st.image(graph_image_path, caption="Workflow Visualization")
85
 
86
- # ------------------- Tab 1: Multiplication -------------------
87
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
88
 
89
  with tab1:
@@ -96,32 +77,25 @@ with tab1:
96
  second_number = st.number_input("Second Number", value=0, step=1)
97
 
98
  if st.button("Multiply"):
99
- question = HumanMessage(content=f"What is {first_number} * {second_number}?")
100
- try:
101
- output = react_graph.invoke({"messages": [question]})
102
- st.success(output['messages'][-1].content)
103
- except Exception as e:
104
- st.error(f"Error: {e}")
105
-
106
- # ------------------- Tab 2: General Queries -------------------
107
  with tab2:
108
  st.subheader("General Query")
109
  user_input = st.text_input("Enter your question here")
110
 
111
  if st.button("Submit"):
112
  if user_input:
113
- question = HumanMessage(content=user_input)
114
  try:
115
- result = react_graph.invoke({"messages": [question]})
116
  st.write("Response:")
117
- st.success(result['messages'][-1].content)
118
  except Exception as e:
119
- st.error(f"Error: {e}")
120
  else:
121
  st.warning("Please enter a valid input.")
122
 
123
- # ------------------- Sidebar Reference -------------------
124
- st.sidebar.title("Reference")
125
- st.sidebar.markdown(
126
- "1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)"
127
- )
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.tools import tool
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
+ from langgraph.graph import StateGraph, END
 
11
 
12
+ # Environment Setup
13
  os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
14
 
15
+ # Model Initialization
16
  model = ChatOpenAI(temperature=0)
17
 
 
18
  @tool
19
  def multiply(first_number: int, second_number: int):
20
  """Multiplies two numbers together and returns the result."""
 
22
 
23
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
24
 
25
+ # State Setup
26
+ class AgentState(TypedDict):
27
+ messages: Annotated[Sequence, operator.add]
28
 
29
+ graph = StateGraph(AgentState)
 
 
 
 
 
30
 
31
+ def invoke_model(state):
32
+ question = state['messages'][-1]
33
+ return {"messages": [model_with_tools.invoke(question)]}
 
 
34
 
35
+ graph.add_node("agent", invoke_model)
36
+
37
+ def invoke_tool(state):
 
 
 
 
 
 
 
 
38
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
39
+ for tool_call in tool_calls:
40
+ if tool_call.get("function").get("name") == "multiply":
41
+ res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
42
+ return {"messages": [f"Tool Result: {res}"]}
43
+ return {"messages": ["No tool input provided."]}
44
 
45
+ graph.add_node("tool", invoke_tool)
46
+ graph.add_edge("tool", END)
47
+ graph.set_entry_point("agent")
 
48
 
49
+ def router(state):
50
+ calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
51
+ return "multiply" if calls else "end"
 
 
 
 
52
 
53
+ graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
54
+ app_graph = graph.compile()
55
 
56
+ # Save graph visualization as an image
57
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
58
  graph_viz = app_graph.get_graph(xray=True)
59
  tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
60
  graph_image_path = tmpfile.name
61
 
62
+ # Streamlit Interface
63
  st.title("Simple Tool Calling Demo")
64
 
65
  # Display the workflow graph
66
  st.image(graph_image_path, caption="Workflow Visualization")
67
 
 
68
  tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
69
 
70
  with tab1:
 
77
  second_number = st.number_input("Second Number", value=0, step=1)
78
 
79
  if st.button("Multiply"):
80
+ question = f"What is {first_number} * {second_number}?"
81
+ output = app_graph.invoke({"messages": [question]})
82
+ st.success(output['messages'][-1])
83
+
 
 
 
 
84
  with tab2:
85
  st.subheader("General Query")
86
  user_input = st.text_input("Enter your question here")
87
 
88
  if st.button("Submit"):
89
  if user_input:
 
90
  try:
91
+ result = app_graph.invoke({"messages": [user_input]})
92
  st.write("Response:")
93
+ st.success(result['messages'][-1])
94
  except Exception as e:
95
+ st.error("Something went wrong. Try again!")
96
  else:
97
  st.warning("Please enter a valid input.")
98
 
99
+
100
+ st.sidebar.title("References")
101
+ st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")