vtony committed
Commit ec64ddf · verified · 1 parent: 600333b

Delete agent.py

Files changed (1)
  1. agent.py +0 -242
agent.py DELETED
@@ -1,242 +0,0 @@
- import os
- import time
- import json
- from dotenv import load_dotenv
- from langgraph.graph import StateGraph, END
- from langgraph.prebuilt import ToolNode, tools_condition
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_community.tools import DuckDuckGoSearchRun
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
- from langchain_core.tools import tool
- from tenacity import retry, stop_after_attempt, wait_exponential
-
- # Load environment variables
- load_dotenv()
- google_api_key = os.getenv("GOOGLE_API_KEY") or os.environ.get("GOOGLE_API_KEY")
- if not google_api_key:
-     raise ValueError("Missing GOOGLE_API_KEY environment variable")
-
- # --- Math Tools ---
- @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two integers."""
-     return a * b
-
- @tool
- def add(a: int, b: int) -> int:
-     """Add two integers."""
-     return a + b
-
- @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract b from a."""
-     return a - b
-
- @tool
- def divide(a: int, b: int) -> float:
-     """Divide a by b, error on zero."""
-     if b == 0:
-         raise ValueError("Cannot divide by zero.")
-     return a / b
-
- @tool
- def modulus(a: int, b: int) -> int:
-     """Compute a mod b."""
-     return a % b
-
- # --- Browser Tools ---
- @tool
- def wiki_search(query: str) -> str:
-     """Search Wikipedia and return up to 3 relevant documents."""
-     try:
-         docs = WikipediaLoader(query=query, load_max_docs=3).load()
-         if not docs:
-             return "No Wikipedia results found."
-
-         results = []
-         for doc in docs:
-             title = doc.metadata.get('title', 'Unknown Title')
-             content = doc.page_content[:2000]  # Limit content length
-             results.append(f"Title: {title}\nContent: {content}")
-
-         return "\n\n---\n\n".join(results)
-     except Exception as e:
-         return f"Wikipedia search error: {str(e)}"
-
- @tool
- def arxiv_search(query: str) -> str:
-     """Search Arxiv and return up to 3 relevant papers."""
-     try:
-         docs = ArxivLoader(query=query, load_max_docs=3).load()
-         if not docs:
-             return "No arXiv papers found."
-
-         results = []
-         for doc in docs:
-             title = doc.metadata.get('Title', 'Unknown Title')
-             authors = ", ".join(doc.metadata.get('Authors', []))
-             content = doc.page_content[:2000]  # Limit content length
-             results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")
-
-         return "\n\n---\n\n".join(results)
-     except Exception as e:
-         return f"arXiv search error: {str(e)}"
-
- @tool
- def web_search(query: str) -> str:
-     """Search the web using DuckDuckGo and return top results."""
-     try:
-         search = DuckDuckGoSearchRun()
-         result = search.run(query)
-         return f"Web search results for '{query}':\n{result[:2000]}"  # Limit content length
-     except Exception as e:
-         return f"Web search error: {str(e)}"
-
- # --- Load system prompt ---
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # --- System message ---
- sys_msg = SystemMessage(content=system_prompt)
-
- # --- Tool Setup ---
- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     arxiv_search,
-     web_search,
- ]
-
- # --- Graph Builder ---
- def build_graph():
-     # Initialize model with Gemini 2.5 Flash - the latest and best FREE model
-     llm = ChatGoogleGenerativeAI(
-         model="gemini-2.5-flash",  # Corrected to the latest free model
-         temperature=0.3,
-         google_api_key=google_api_key,
-         max_retries=3
-     )
-
-     # Bind tools to LLM
-     llm_with_tools = llm.bind_tools(tools)
-
-     # Define state with proper initialization
-     class AgentState:
-         def __init__(self, messages):
-             self.messages = messages
-
-     # Node definitions with error handling
-     def agent_node(state):
-         """Main agent node that processes messages with retry logic"""
-         try:
-             # Add rate limiting
-             time.sleep(1)  # 1 second delay between requests
-
-             # Add retry logic for API quota issues
-             @retry(stop=stop_after_attempt(3),
-                    wait=wait_exponential(multiplier=1, min=4, max=10))
-             def invoke_llm_with_retry():
-                 return llm_with_tools.invoke(state.messages)
-
-             response = invoke_llm_with_retry()
-             return AgentState(messages=state.messages + [response])
-
-         except Exception as e:
-             # Handle specific errors
-             error_type = "UNKNOWN"
-             if "429" in str(e):
-                 error_type = "QUOTA_EXCEEDED"
-             elif "400" in str(e):
-                 error_type = "INVALID_REQUEST"
-
-             error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
-             return AgentState(messages=state.messages + [AIMessage(content=error_msg)])
-
-     # Tool node
-     def tool_node(state):
-         """Execute tools based on agent's request"""
-         last_message = state.messages[-1]
-         tool_calls = last_message.additional_kwargs.get("tool_calls", [])
-
-         tool_responses = []
-         for tool_call in tool_calls:
-             tool_name = tool_call["function"]["name"]
-             tool_args = tool_call["function"].get("arguments", {})
-
-             # Find the tool
-             tool_func = next((t for t in tools if t.name == tool_name), None)
-             if not tool_func:
-                 tool_responses.append(f"Tool {tool_name} not found")
-                 continue
-
-             try:
-                 # Execute the tool
-                 if isinstance(tool_args, str):
-                     # Parse JSON if arguments are in string format
-                     tool_args = json.loads(tool_args)
-
-                 result = tool_func.invoke(tool_args)
-                 tool_responses.append(f"Tool {tool_name} result: {result}")
-             except Exception as e:
-                 tool_responses.append(f"Tool {tool_name} error: {str(e)}")
-
-         tool_response_content = "\n".join(tool_responses)
-         return AgentState(messages=state.messages + [AIMessage(content=tool_response_content)])
-
-     # Custom condition function
-     def should_continue(state):
-         last_message = state.messages[-1]
-
-         # If there was an error, end
-         if "AGENT ERROR" in last_message.content:
-             return "end"
-
-         # Check for tool calls
-         if hasattr(last_message, "tool_calls") and last_message.tool_calls:
-             return "tools"
-
-         # Check for final answer
-         if "FINAL ANSWER" in last_message.content:
-             return "end"
-
-         # Otherwise, continue to agent
-         return "agent"
-
-     # Build the graph
-     workflow = StateGraph(AgentState)
-
-     # Add nodes
-     workflow.add_node("agent", agent_node)
-     workflow.add_node("tools", tool_node)
-
-     # Set entry point
-     workflow.set_entry_point("agent")
-
-     # Define edges
-     workflow.add_conditional_edges(
-         "agent",
-         should_continue,
-         {
-             "agent": "agent",
-             "tools": "tools",
-             "end": END
-         }
-     )
-
-     workflow.add_conditional_edges(
-         "tools",
-         lambda state: "agent",
-         {
-             "agent": "agent"
-         }
-     )
-
-     return workflow.compile()
-
- # Initialize the agent graph
- agent_graph = build_graph()
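
For context only, and not part of the commit above: the deleted module hand-wires a Gemini-backed tool-calling agent, including a custom AgentState class and a manual tool-dispatch node. The sketch below shows the same pattern using LangGraph's prebuilt MessagesState, ToolNode, and tools_condition instead. The gemini-2.5-flash model name and the GOOGLE_API_KEY variable are carried over from the deleted file; the single add tool and everything else are illustrative assumptions, not code from this repository.

import os

from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.prebuilt import ToolNode, tools_condition


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


def build_graph():
    # Bind the tool to the chat model so it can emit structured tool calls.
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=os.environ["GOOGLE_API_KEY"],
    )
    llm_with_tools = llm.bind_tools([add])

    def agent_node(state: MessagesState):
        # Append the model's reply (which may contain tool calls) to the message list.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", ToolNode([add]))
    workflow.add_edge(START, "agent")
    # tools_condition routes to "tools" when the last message carries tool calls, otherwise ends.
    workflow.add_conditional_edges("agent", tools_condition)
    workflow.add_edge("tools", "agent")
    return workflow.compile()


if __name__ == "__main__":
    graph = build_graph()
    result = graph.invoke({"messages": [("user", "What is 2 + 3?")]})
    print(result["messages"][-1].content)

Compared with the deleted file, the prebuilt ToolNode appends individual ToolMessage results rather than one concatenated AIMessage, so the model sees each tool's output separately on the next turn.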