DrishtiSharma commited on
Commit
e19b96d
·
verified ·
1 Parent(s): 5472503

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -80
app.py CHANGED
@@ -1,118 +1,101 @@
1
  import os
2
  import json
3
  import operator
 
4
  import tempfile
5
  from typing import TypedDict, Annotated, Sequence
6
- from dotenv import load_dotenv
7
-
8
  from langchain_openai import ChatOpenAI
9
- from langchain_core.messages import BaseMessage
10
  from langchain_core.tools import tool
11
  from langchain_core.utils.function_calling import convert_to_openai_tool
12
  from langgraph.graph import StateGraph, END
13
 
14
- import streamlit as st
15
-
16
- # ------------------- Streamlit UI Layout -------------------
17
- #st.set_page_config(page_title="Streamlit LLM Graph", layout="wide")
18
- st.title("Test App- Tool Calling and Conditional Graph")
19
-
20
- # ------------------- Environment Setup -------------------
21
 
22
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
23
- if not OPENAI_API_KEY:
24
- st.error("OpenAI API Key not found! Please set it in your environment variables.")
25
- st.stop()
26
-
27
- # ------------------- Model Initialization -------------------
28
  model = ChatOpenAI(temperature=0)
29
 
30
- # ------------------- Tool Definition -------------------
31
  @tool
32
  def multiply(first_number: int, second_number: int):
33
- """Multiplies two numbers together."""
34
  return first_number * second_number
35
 
36
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
37
 
38
- # ------------------- State Graph Setup -------------------
39
  class AgentState(TypedDict):
40
- messages: Annotated[Sequence[BaseMessage], operator.add]
41
 
42
  graph = StateGraph(AgentState)
43
 
44
  def invoke_model(state):
45
- messages = state['messages']
46
- question = messages[-1]
47
  return {"messages": [model_with_tools.invoke(question)]}
48
 
49
  graph.add_node("agent", invoke_model)
50
 
51
  def invoke_tool(state):
52
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
53
- multiply_call = None
54
-
55
  for tool_call in tool_calls:
56
  if tool_call.get("function").get("name") == "multiply":
57
- multiply_call = tool_call
58
-
59
- if multiply_call is None:
60
- raise Exception("No multiply input found.")
61
-
62
- res = multiply.invoke(
63
- json.loads(multiply_call.get("function").get("arguments"))
64
- )
65
- return {"messages": [res]}
66
 
67
  graph.add_node("tool", invoke_tool)
68
  graph.add_edge("tool", END)
69
  graph.set_entry_point("agent")
70
 
71
  def router(state):
72
- tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
73
- if len(tool_calls):
74
- return "multiply"
75
- else:
76
- return "end"
77
-
78
- graph.add_conditional_edges("agent", router, {
79
- "multiply": "tool",
80
- "end": END,
81
- })
82
-
83
- app = graph.compile()
84
-
85
- # ------------------- Graph Visualization -------------------
86
- st.header("Graph Visualization")
87
-
88
- try:
89
- # Generate graph visualization and save as image
90
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
91
- graph_viz = app.get_graph(xray=True)
92
- tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
93
- graph_image_path = tmpfile.name
94
-
95
- # Display the graph image
96
- st.image(graph_image_path, caption="State Graph Visualization", use_column_width=True)
97
- except Exception as e:
98
- st.error(f"Error generating graph visualization: {e}")
99
-
100
- # ------------------- Streamlit Interaction -------------------
101
- st.header("Ask the LLM")
102
-
103
- # User Input
104
- user_query = st.text_input("Enter your query:", placeholder="E.g., What is 123 * 456?")
105
- submit_button = st.button("Submit")
106
-
107
- # Process Query
108
- if submit_button and user_query:
109
- try:
110
- # Invoke the graph app
111
- output = app.invoke({"messages": [user_query]})
112
- final_response = output['messages'][-1]
113
-
114
- # Display Results
115
- st.subheader("Response")
116
- st.write(final_response)
117
- except Exception as e:
118
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
1
  import os
2
  import json
3
  import operator
4
+ import streamlit as st
5
  import tempfile
6
  from typing import TypedDict, Annotated, Sequence
 
 
7
  from langchain_openai import ChatOpenAI
 
8
  from langchain_core.tools import tool
9
  from langchain_core.utils.function_calling import convert_to_openai_tool
10
  from langgraph.graph import StateGraph, END
11
 
12
+ # Environment Setup
13
+ os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
 
 
 
 
 
14
 
15
+ # Model Initialization
 
 
 
 
 
16
  model = ChatOpenAI(temperature=0)
17
 
 
18
  @tool
19
  def multiply(first_number: int, second_number: int):
20
+ """Multiplies two numbers together and returns the result."""
21
  return first_number * second_number
22
 
23
  model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
24
 
25
+ # State Setup
26
  class AgentState(TypedDict):
27
+ messages: Annotated[Sequence, operator.add]
28
 
29
  graph = StateGraph(AgentState)
30
 
31
  def invoke_model(state):
32
+ question = state['messages'][-1]
 
33
  return {"messages": [model_with_tools.invoke(question)]}
34
 
35
  graph.add_node("agent", invoke_model)
36
 
37
  def invoke_tool(state):
38
  tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
 
 
39
  for tool_call in tool_calls:
40
  if tool_call.get("function").get("name") == "multiply":
41
+ res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
42
+ return {"messages": [f"Tool Result: {res}"]}
43
+ return {"messages": ["No tool input provided."]}
 
 
 
 
 
 
44
 
45
  graph.add_node("tool", invoke_tool)
46
  graph.add_edge("tool", END)
47
  graph.set_entry_point("agent")
48
 
49
  def router(state):
50
+ calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
51
+ return "multiply" if calls else "end"
52
+
53
+ graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
54
+ app_graph = graph.compile()
55
+
56
+ # Save graph visualization as an image
57
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
58
+ graph_viz = app_graph.get_graph(xray=True)
59
+ tmpfile.write(graph_viz.draw_mermaid_png()) # Write binary image data to file
60
+ graph_image_path = tmpfile.name
61
+
62
+ # Streamlit Interface
63
+ st.title("Simple Tool Calling Demo")
64
+
65
+ # Display the workflow graph
66
+ st.image(graph_image_path, caption="Workflow Visualization")
67
+
68
+ tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])
69
+
70
+ with tab1:
71
+ st.subheader("Try Multiplication")
72
+ col1, col2 = st.columns(2)
73
+
74
+ with col1:
75
+ first_number = st.number_input("First Number", value=0, step=1)
76
+ with col2:
77
+ second_number = st.number_input("Second Number", value=0, step=1)
78
+
79
+ if st.button("Multiply"):
80
+ question = f"What is {first_number} * {second_number}?"
81
+ output = app_graph.invoke({"messages": [question]})
82
+ st.success(output['messages'][-1])
83
+
84
+ with tab2:
85
+ st.subheader("General Query")
86
+ user_input = st.text_input("Enter your question here")
87
+
88
+ if st.button("Submit"):
89
+ if user_input:
90
+ try:
91
+ result = app_graph.invoke({"messages": [user_input]})
92
+ st.write("Response:")
93
+ st.success(result['messages'][-1])
94
+ except Exception as e:
95
+ st.error("Something went wrong. Try again!")
96
+ else:
97
+ st.warning("Please enter a valid input.")
98
+
99
+
100
+ st.sidebar.title("References")
101
+ st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")