keenthinker committed on
Commit 966e533 · verified · 1 Parent(s): f6bfa8e

Update agent.py

Files changed (1)
  1. agent.py +106 -220
agent.py CHANGED
@@ -1,223 +1,109 @@
- """LangGraph Agent"""
  import os
- from dotenv import load_dotenv
- from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
- from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
- from langchain_community.vectorstores import SupabaseVectorStore
- from langchain_core.messages import SystemMessage, HumanMessage
- from langchain_core.tools import tool
- from langchain.tools.retriever import create_retriever_tool
- from supabase.client import Client, create_client
-
- load_dotenv()
-
- @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-     Args:
-         a: first int
-         b: second int
-     """
-     return a * b
-
- @tool
- def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a + b
-
- @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a - b
-
- @tool
- def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     if b == 0:
-         raise ValueError("Cannot divide by zero.")
-     return a / b
-
- @tool
- def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a % b
-
- @tool
- def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}
-
- @tool
- def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"web_results": formatted_search_docs}
-
- @tool
- def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}
-
-
-
- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
- sys_msg = SystemMessage(content=system_prompt)
-
- # build a retriever
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
- supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"),
-     os.environ.get("SUPABASE_SERVICE_KEY"))
- vector_store = SupabaseVectorStore(
-     client=supabase,
-     embedding=embeddings,
-     table_name="documents",
-     query_name="match_documents_langchain",
- )
- create_retriever_tool = create_retriever_tool(
-     retriever=vector_store.as_retriever(),
-     name="Question Search",
-     description="A tool to retrieve similar questions from a vector store.",
- )
-
-
-
- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arvix_search,
- ]
-
- # Build graph function
- def build_graph(provider: str = "google"):
-     """Build the graph"""
-     # Load environment variables from .env file
-     if provider == "google":
-         # Google Gemini
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
-     elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
-         llm = ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             ),
-         )
-     else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
-     llm_with_tools = llm.bind_tools(tools)
-
-     # Node
-     def assistant(state: MessagesState):
-         """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-     # def retriever(state: MessagesState):
-     #     """Retriever node"""
-     #     similar_question = vector_store.similarity_search(state["messages"][0].content)
-     #     example_msg = HumanMessage(
-     #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-     #     )
-     #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-     from langchain_core.messages import AIMessage
-
-     def retriever(state: MessagesState):
-         query = state["messages"][-1].content
-         similar_doc = vector_store.similarity_search(query, k=1)[0]
-
-         content = similar_doc.page_content
-         if "Final answer :" in content:
-             answer = content.split("Final answer :")[-1].strip()
-         else:
-             answer = content.strip()
-
-         return {"messages": [AIMessage(content=answer)]}
-
-     # builder = StateGraph(MessagesState)
-     # builder.add_node("retriever", retriever)
-     # builder.add_node("assistant", assistant)
-     # builder.add_node("tools", ToolNode(tools))
-     # builder.add_edge(START, "retriever")
-     # builder.add_edge("retriever", "assistant")
-     # builder.add_conditional_edges(
-     #     "assistant",
-     #     tools_condition,
-     # )
-     # builder.add_edge("tools", "assistant")
-
-     builder = StateGraph(MessagesState)
-     builder.add_node("retriever", retriever)
-
-     # The retriever is both the entry point and the finish point
-     builder.set_entry_point("retriever")
-     builder.set_finish_point("retriever")
-
-     # Compile graph
-     return builder.compile()
 
+ from smolagents import Tool, tool, CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool
+ import time
  import os
+ import requests
+ import markdownify
+ # for reading the mp3 files
+ import whisper
+ import tempfile
+ import io
+
+
+ class Mod4Agent:
+
+     def __init__(self):
+         self.api_key = os.getenv("OPENAI_KEY")
+
+         # base model
+         self.model = OpenAIServerModel(
+             model_id="gpt-4o",
+             api_base="https://api.openai.com/v1",
+             temperature=0.0,
+             api_key=self.api_key)
+
+         # base prompt
+         self.base_prompt = """
+         You are an agent with a set of tools for answering questions.
+         You need to be accurate and get the best possible answer in the simplest possible way.
+         You need to think step by step, and if at some point there is an error, backtrack and use a different method.
+         It is important to adhere to the instructions of the question as closely as possible.
+         IMPORTANT: always answer in the required format, to the best of your abilities. Stating that you do not know, or explaining why, gives a score of 0 and is therefore to be avoided.
+         You can do it!
+
+         Question:
+         """
+
+         @tool
+         def audio_interpreter(input: bytes) -> str:
+             """
+             Transcribe an mp3 file, given as raw bytes or as a file path, into the corresponding text.
+
+             Args:
+                 input: raw bytes content of the input mp3 file, or its file path
+
+             Returns:
+                 str: the text corresponding to the mp3 input file
+             """
+             model = whisper.load_model("tiny")
+
+             if isinstance(input, bytes):
+                 with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp:
+                     tmp.write(input)
+                     tmp.flush()
+                     result = model.transcribe(tmp.name)
+
+             elif isinstance(input, str) and os.path.exists(input):
+                 # Safe if the HF environment mounts the file
+                 result = model.transcribe(input)
+
+             else:
+                 raise TypeError("Unsupported input type. Expected bytes or a valid file path.")
+
+             return result["text"]
+
+         self.list_tools = [DuckDuckGoSearchTool(), audio_interpreter]
+
+         self.agent = CodeAgent(tools=self.list_tools,
+                                model=self.model,
+                                additional_authorized_imports=['pandas', 'io', 'requests', 'markdownify'],
+                                max_steps=10,
+                                add_base_tools=True  # also add the default base tools
+                                # planning_interval=3  # enable a planning step every 3 steps
+                                )
+
+         print("Mod4Agent initialized.")
+
+     # Retry policy for when the quota is exceeded
+     def retry(self, prompt):
+         backoff = 20
+         while True:
+             try:
+                 return self.agent.run(prompt)  # success
+             except Exception as e:
+                 if "429" in str(e):
+                     print(f"Rate limit hit. Sleeping for {backoff} seconds...")
+                     time.sleep(backoff)
+                     backoff = min(backoff * 2, 80)  # max backoff = 80 seconds
+                 else:
+                     print("Error:", e)
+                     return None  # give up on non-rate-limit errors
+
+     def __call__(self, question: str) -> str:
+         print(f"Agent received question (first 50 chars): {question[:50]}...")
+         prompt = f'{self.base_prompt}\n {question}'
+         answer = self.retry(prompt)
+         print(f"Agent returning answer: {answer}")
+         return answer
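
For reference, a minimal usage sketch of the new Mod4Agent class, not part of the commit itself. It assumes smolagents and openai-whisper are installed, that this file is importable as agent.py, and that the OPENAI_KEY environment variable used by __init__ holds a valid OpenAI API key; the question string is a made-up example.

# Usage sketch (illustrative only): requires OPENAI_KEY to be set in the
# environment and smolagents / openai-whisper to be installed.
from agent import Mod4Agent

agent = Mod4Agent()  # builds the OpenAIServerModel, the tools, and the CodeAgent
# __call__ prepends the base prompt and retries with exponential backoff on 429 errors
answer = agent("What is 6 * 7? Answer with the number only.")
print(answer)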