DrishtiSharma commited on
Commit
693beb7
·
verified ·
1 Parent(s): 07aac68

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import re
4
+ import base64
5
+ import streamlit as st
6
+ from io import BytesIO
7
+ from langchain_core.utils.function_calling import convert_to_openai_function
8
+ from langchain_core.messages import (
9
+ AIMessage,
10
+ BaseMessage,
11
+ ChatMessage,
12
+ FunctionMessage,
13
+ HumanMessage,
14
+ )
15
+ from langchain.tools.render import format_tool_to_openai_function
16
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langgraph.graph import END, StateGraph
18
+ from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
19
+ from langchain_core.tools import tool
20
+ from langchain_community.tools.tavily_search import TavilySearchResults
21
+ from langchain_experimental.utilities import PythonREPL
22
+ from langchain_openai import ChatOpenAI
23
+ from typing import Annotated, Sequence
24
+ from typing_extensions import TypedDict
25
+ import operator
26
+ import functools
27
+ import matplotlib.pyplot as plt
28
+
29
# Read required API keys from the environment (set as deployment secrets).
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Fail fast, naming exactly which key(s) are missing so the operator
# doesn't have to guess (the old message listed both unconditionally).
_missing_keys = [
    name
    for name, value in (("TAVILY_API_KEY", TAVILY_API_KEY), ("OPENAI_API_KEY", OPENAI_API_KEY))
    if not value
]
if _missing_keys:
    st.error(f"API keys are missing: {', '.join(_missing_keys)}. Please set them as secrets.")
    st.stop()
37
+
38
# Define the AgentState class: shared state flowing between graph nodes.
class AgentState(TypedDict):
    # Conversation history. Annotated with operator.add so successive node
    # updates are concatenated rather than replaced by the graph runtime.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest update; the tool node's
    # outgoing edge uses this to route control back to the calling agent.
    sender: str
42
+
43
# Initialize tools
tavily_tool = TavilySearchResults(max_results=5)  # web search, top 5 hits
# NOTE(review): `repl` looks unused in this file — python_repl below calls
# exec() directly instead of going through PythonREPL. Confirm before removing.
repl = PythonREPL()
46
+
47
@tool
def python_repl(code: Annotated[str, "The python code to execute to generate your chart."]):
    """Executes Python code to generate a chart and returns it base64-encoded.

    Returns a dict: {"status": "success", "image": <base64 PNG string>} on
    success, or {"status": "failed", "error": <repr of the exception>} when
    the supplied code raises.

    SECURITY: exec() runs arbitrary model-generated code in-process with no
    sandboxing — only deploy behind trusted inputs or an isolated runtime.
    """
    try:
        # Use ONE namespace for both globals and locals. With separate dicts,
        # exec() evaluates the code like a class body, so functions and
        # comprehensions defined inside `code` cannot see its other variables
        # (a classic exec pitfall); a single dict gives normal module scoping.
        namespace = {"plt": plt}
        exec(code, namespace)

        # Serialize whatever figure the code drew into an in-memory PNG.
        buf = BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)

        # Reset pyplot's implicit state so the next call starts clean.
        plt.clf()
        plt.close()

        # Base64-encode so the image travels inside a JSON-serializable dict.
        encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"status": "success", "image": encoded_image}
    except Exception as e:
        return {"status": "failed", "error": repr(e)}
70
+
71
# All tools the agents may request (web search + chart generation).
tools = [tavily_tool, python_repl]

# Define a tool executor
# Dispatches a ToolInvocation to the matching tool by name.
tool_executor = ToolExecutor(tools)
75
+
76
# Define tool node
def tool_node(state):
    """Executes tools in the graph.

    Reads the function call requested by the last message, runs the matching
    tool via tool_executor, and returns a state update with the result.
    """
    messages = state["messages"]
    last_message = messages[-1]
    # The OpenAI function-call arguments arrive as a JSON string.
    tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"])
    # Single-argument tools are wrapped as {"__arg1": value}; unwrap to the
    # bare value so the tool receives what it expects.
    if len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = next(iter(tool_input.values()))
    tool_name = last_message.additional_kwargs["function_call"]["name"]
    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    response = tool_executor.invoke(action)
    # Success path for python_repl, which returns {"status": ..., "image": ...}.
    if isinstance(response, dict) and response.get("status") == "success" and "image" in response:
        # NOTE(review): this appends a plain dict (with an extra "image" key)
        # to a message list that otherwise holds BaseMessage objects — confirm
        # downstream prompt templating and the router tolerate dict messages.
        return {
            "messages": [
                {
                    "role": "assistant",
                    "content": "Image generated successfully.",
                    "image": response["image"],
                }
            ]
        }
    else:
        # Everything else (search results, errors) goes back as a
        # FunctionMessage so the requesting agent can read the outcome.
        function_message = FunctionMessage(
            content=f"{tool_name} response: {str(response)}", name=action.tool
        )
        return {"messages": [function_message]}
102
+
103
# Define router
def router(state):
    """Pick the next edge for the current agent's output.

    Returns "call_tool" when the last message requests a function call,
    "end" once an agent has declared a final answer, and "continue" to
    hand off to the other agent otherwise.
    """
    latest = state["messages"][-1]
    if "function_call" in latest.additional_kwargs:
        return "call_tool"
    return "end" if "FINAL ANSWER" in latest.content else "continue"
113
+
114
# Define agent creation function
def create_agent(llm, tools, system_message: str):
    """Build a prompt | llm pipeline for one collaborating agent.

    The system prompt instructs the agent to cooperate with its peers and to
    prefix a finished deliverable with FINAL ANSWER; the given tools are
    bound to the LLM as OpenAI functions.
    """
    system_template = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        " You have access to the following tools: {tool_names}.\n{system_message}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            MessagesPlaceholder(variable_name="messages"),
        ]
    ).partial(
        system_message=system_message,
        tool_names=", ".join(t.name for t in tools),
    )
    openai_functions = [convert_to_openai_function(t) for t in tools]
    return prompt | llm.bind_functions(openai_functions)
136
+
137
# Define agent node
def agent_node(state, agent, name):
    """Run one agent on the shared state and record it as the sender.

    Non-function-call outputs are re-wrapped as HumanMessages carrying the
    agent's (sanitized) name so the peer agent can see who spoke last.
    """
    result = agent.invoke(state)
    if not isinstance(result, FunctionMessage):
        # OpenAI message names only allow [a-zA-Z0-9_-]; replace the rest.
        safe_name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
        payload = result.dict(exclude={"type", "name"})
        result = HumanMessage(**payload, name=safe_name)
    return {"messages": [result], "sender": name}
147
+
148
# Initialize LLM
# NOTE(review): no model name is given, so ChatOpenAI's default model is used —
# confirm that default supports OpenAI function calling, which the graph relies on.
llm = ChatOpenAI(api_key=OPENAI_API_KEY)

# Create agents
# Researcher: gathers data via Tavily web search.
research_agent = create_agent(
    llm, [tavily_tool], system_message="You should provide accurate data for the chart generator to use."
)
# Chart generator: renders data into charts via the python_repl tool.
chart_agent = create_agent(
    llm, [python_repl], system_message="Any charts you display will be visible by the user."
)
158
+
159
# Define workflow graph
workflow = StateGraph(AgentState)
# Two agent nodes plus a shared tool-execution node.
workflow.add_node("Researcher", functools.partial(agent_node, agent=research_agent, name="Researcher"))
workflow.add_node("Chart Generator", functools.partial(agent_node, agent=chart_agent, name="Chart Generator"))
workflow.add_node("call_tool", tool_node)
# Each agent either hands off to the other, calls a tool, or ends the run.
workflow.add_conditional_edges("Researcher", router, {"continue": "Chart Generator", "call_tool": "call_tool", "end": END})
workflow.add_conditional_edges("Chart Generator", router, {"continue": "Researcher", "call_tool": "call_tool", "end": END})
# After a tool runs, control returns to whichever agent requested it (state["sender"]).
workflow.add_conditional_edges("call_tool", lambda x: x["sender"], {"Researcher": "Researcher", "Chart Generator": "Chart Generator"})
workflow.set_entry_point("Researcher")
graph = workflow.compile()
169
+
170
# Streamlit UI
st.title("Multi-Agent Workflow")
user_query = st.text_area("Enter your query:", "Fetch Malaysia's GDP over the past 5 years and draw a line graph.")
if st.button("Run Workflow"):
    st.write("Running workflow...")
    with st.spinner("Processing..."):
        try:
            messages = [HumanMessage(content=user_query)]
            for step in graph.stream({"messages": messages}, {"recursion_limit": 150}):
                st.write("Step Details:", step)
                # BUG FIX: graph.stream() yields {node_name: state_update}
                # mappings, so "messages" lives one level down; the old
                # `if "messages" in step` check never matched. Walk each
                # node's update instead.
                for node_output in step.values():
                    if not isinstance(node_output, dict) or "messages" not in node_output:
                        continue
                    for message in node_output["messages"]:
                        # tool_node may emit plain dicts (possibly carrying a
                        # base64 chart); agent nodes emit message objects, on
                        # which `"image" in message` would raise TypeError.
                        if isinstance(message, dict) and "image" in message:
                            try:
                                # Decode the base64-encoded image and display it.
                                decoded_image = BytesIO(base64.b64decode(message["image"]))
                                st.image(decoded_image, caption="Generated Chart", use_column_width=True)
                            except Exception as e:
                                st.error(f"Failed to decode and display the image: {repr(e)}")
                        elif isinstance(message, dict) and "content" in message:
                            # Display any text content from dict-style messages.
                            st.write(message["content"])
                        elif hasattr(message, "content"):
                            # LangChain message objects expose .content directly.
                            st.write(message.content)
        except Exception as e:
            st.error(f"An error occurred: {e}")