DrishtiSharma committed on
Commit
5472503
·
verified ·
1 Parent(s): 1d4c700

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -63
app.py CHANGED
@@ -1,101 +1,118 @@
# NOTE(review): previous revision of app.py (pre-commit 5472503), reconstructed
# from this diff's removed/context lines and kept for reference.
import os
import json
import operator
import streamlit as st
import tempfile
from typing import TypedDict, Annotated, Sequence

from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langgraph.graph import StateGraph, END

# Environment Setup
# FIX: the original `os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")`
# raises TypeError at startup when the variable is unset (os.environ values
# must be str, not None). Guard the assignment instead.
_api_key = os.getenv("OPENAI_API_KEY")
if _api_key:
    os.environ['OPENAI_API_KEY'] = _api_key

# Model Initialization (temperature=0 -> deterministic responses)
model = ChatOpenAI(temperature=0)

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together and returns the result."""
    return first_number * second_number

model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

# State Setup
class AgentState(TypedDict):
    # operator.add makes LangGraph append new messages rather than replace.
    messages: Annotated[Sequence, operator.add]

graph = StateGraph(AgentState)

def invoke_model(state):
    """Call the tool-bound model on the most recent message."""
    question = state['messages'][-1]
    return {"messages": [model_with_tools.invoke(question)]}

graph.add_node("agent", invoke_model)

def invoke_tool(state):
    """Run the first `multiply` tool call found on the last message."""
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    for tool_call in tool_calls:
        if tool_call.get("function").get("name") == "multiply":
            # Tool arguments arrive as a JSON string from the OpenAI API.
            res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
            return {"messages": [f"Tool Result: {res}"]}
    return {"messages": ["No tool input provided."]}

graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
graph.set_entry_point("agent")

def router(state):
    """Route to the tool node when the model requested a tool call."""
    calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    return "multiply" if calls else "end"

graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
app_graph = graph.compile()

# Save graph visualization as an image (delete=False so Streamlit can read
# the file after the handle is closed).
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
    graph_viz = app_graph.get_graph(xray=True)
    tmpfile.write(graph_viz.draw_mermaid_png())  # Write binary image data to file
    graph_image_path = tmpfile.name

# Streamlit Interface
st.title("Simple Tool Calling Demo")

# Display the workflow graph
st.image(graph_image_path, caption="Workflow Visualization")

tab1, tab2 = st.tabs(["Try Multiplication", "Ask General Queries"])

with tab1:
    st.subheader("Try Multiplication")
    col1, col2 = st.columns(2)

    with col1:
        first_number = st.number_input("First Number", value=0, step=1)
    with col2:
        second_number = st.number_input("Second Number", value=0, step=1)

    if st.button("Multiply"):
        question = f"What is {first_number} * {second_number}?"
        output = app_graph.invoke({"messages": [question]})
        st.success(output['messages'][-1])

with tab2:
    st.subheader("General Query")
    user_input = st.text_input("Enter your question here")

    if st.button("Submit"):
        if user_input:
            try:
                result = app_graph.invoke({"messages": [user_input]})
                st.write("Response:")
                st.success(result['messages'][-1])
            except Exception as e:
                st.error("Something went wrong. Try again!")
        else:
            st.warning("Please enter a valid input.")


st.sidebar.title("References")
st.sidebar.markdown("1. [LangGraph Tool Calling](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_02_ToolCalling.ipynb)")
 
1
  import os
2
  import json
3
  import operator
 
4
  import tempfile
5
  from typing import TypedDict, Annotated, Sequence
6
+ from dotenv import load_dotenv
7
+
8
  from langchain_openai import ChatOpenAI
9
+ from langchain_core.messages import BaseMessage
10
  from langchain_core.tools import tool
11
  from langchain_core.utils.function_calling import convert_to_openai_tool
12
  from langgraph.graph import StateGraph, END
13
 
14
+ import streamlit as st
15
+
# ------------------- Streamlit UI Layout -------------------
#st.set_page_config(page_title="Streamlit LLM Graph", layout="wide")
st.title("Test App- Tool Calling and Conditional Graph")

# ------------------- Environment Setup -------------------
# FIX: `load_dotenv` was imported at the top of the file but never called,
# so keys stored in a local .env file were silently ignored. Load it before
# reading the environment.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Halt the Streamlit script early; nothing below works without a key.
    st.error("OpenAI API Key not found! Please set it in your environment variables.")
    st.stop()

# ------------------- Model Initialization -------------------
# temperature=0 -> deterministic responses; ChatOpenAI reads the key from
# the OPENAI_API_KEY environment variable.
model = ChatOpenAI(temperature=0)
# ------------------- Tool Definition -------------------
@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    # NOTE: the docstring doubles as the tool description shown to the model.
    product = first_number * second_number
    return product


# Expose the tool to the LLM in OpenAI function-calling format.
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
# ------------------- State Graph Setup -------------------
class AgentState(TypedDict):
    # Conversation history. `operator.add` tells LangGraph to append newly
    # returned messages to the existing sequence instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]


graph = StateGraph(AgentState)
def invoke_model(state):
    """Call the tool-bound model on the most recent message in the state."""
    latest_message = state["messages"][-1]
    response = model_with_tools.invoke(latest_message)
    return {"messages": [response]}


graph.add_node("agent", invoke_model)
def invoke_tool(state):
    """Execute the `multiply` tool requested by the model's last message.

    Scans the tool calls attached to the most recent message, runs the first
    `multiply` call found, and returns its result as a new message.

    Raises:
        ValueError: if the last message carries no `multiply` tool call.
    """
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])

    # Take the FIRST matching call and stop scanning. The original loop kept
    # overwriting `multiply_call` with later matches, which was accidental.
    multiply_call = next(
        (call for call in tool_calls
         if call.get("function").get("name") == "multiply"),
        None,
    )

    if multiply_call is None:
        # ValueError is more specific than bare Exception and is still caught
        # by the existing `except Exception` handlers downstream.
        raise ValueError("No multiply input found.")

    # Tool arguments arrive as a JSON-encoded string from the OpenAI API.
    args = json.loads(multiply_call.get("function").get("arguments"))
    res = multiply.invoke(args)
    return {"messages": [res]}

# Register the tool node and wire the static edges: the tool node always
# terminates a run, and every run begins at the agent node.
graph.add_node("tool", invoke_tool)
graph.add_edge("tool", END)
graph.set_entry_point("agent")
def router(state):
    """Decide where to go after the agent node.

    Returns "multiply" when the last message carries tool calls (so the tool
    node runs next), otherwise "end" to finish the run.
    """
    tool_calls = state["messages"][-1].additional_kwargs.get("tool_calls", [])
    # Truthiness is the idiomatic form of the original `if len(tool_calls):`.
    return "multiply" if tool_calls else "end"

# Conditional edge out of the agent: hand off to the tool node when the
# router reports a tool call, otherwise finish the run.
graph.add_conditional_edges(
    "agent",
    router,
    {"multiply": "tool", "end": END},
)

app = graph.compile()
# ------------------- Graph Visualization -------------------
st.header("Graph Visualization")

try:
    # Render the compiled graph to a PNG. delete=False keeps the temp file
    # on disk so Streamlit can read it after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        workflow_diagram = app.get_graph(xray=True)
        tmpfile.write(workflow_diagram.draw_mermaid_png())  # binary PNG data
        graph_image_path = tmpfile.name

    st.image(graph_image_path, caption="State Graph Visualization", use_column_width=True)
except Exception as e:
    # Visualization is best-effort; surface the error without stopping the app.
    st.error(f"Error generating graph visualization: {e}")
# ------------------- Streamlit Interaction -------------------
st.header("Ask the LLM")

# Collect the user's question and a submit trigger.
user_query = st.text_input("Enter your query:", placeholder="E.g., What is 123 * 456?")
submit_button = st.button("Submit")

# Run the graph only when the button was pressed with a non-empty query.
if submit_button and user_query:
    try:
        graph_output = app.invoke({"messages": [user_query]})
        st.subheader("Response")
        # The last state message holds the final answer (or tool result).
        st.write(graph_output["messages"][-1])
    except Exception as e:
        st.error(f"An error occurred: {e}")