orbulat commited on
Commit
fb83bd7
·
verified ·
1 Parent(s): a7ab281

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +63 -31
agent.py CHANGED
@@ -123,6 +123,26 @@ class WikiContentFetcher(Tool):
123
  except wiki.exceptions.PageError:
124
  return f"'{page_title}' not found."
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  class FileAttachmentQueryTool(Tool):
127
  name = "run_query_with_file"
128
  description = """
@@ -132,12 +152,8 @@ class FileAttachmentQueryTool(Tool):
132
  inputs = {
133
  "task_id": {
134
  "type": "string",
135
- "description": "A unique identifier for the task related to this file, used to download it."
136
- },
137
- "mime_type": {
138
- "type": "string",
139
- "nullable": True,
140
- "description": "The MIME type of the file, or the best guess if unknown."
141
  },
142
  "user_query": {
143
  "type": "string",
@@ -146,18 +162,16 @@ class FileAttachmentQueryTool(Tool):
146
  }
147
  output_type = "string"
148
 
149
- def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
150
  file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
151
  file_response = requests.get(file_url)
152
  if file_response.status_code != 200:
153
  return f"Failed to download file: {file_response.status_code} - {file_response.text}"
154
  file_data = file_response.content
155
- mime_type = mime_type or file_response.headers.get('Content-Type', 'application/octet-stream')
156
-
157
  from google.generativeai import GenerativeModel
158
  model = GenerativeModel(self.model_name)
159
  response = model.generate_content([
160
- types.Part.from_bytes(data=file_data, mime_type=mime_type),
161
  user_query
162
  ])
163
 
@@ -170,6 +184,7 @@ class BasicAgent:
170
  model = self.select_model(provider)
171
  client = InferenceClientModel()
172
  tools = [
 
173
  DuckDuckGoSearchTool(),
174
  GeminiVideoQA(GEMINI_MODEL_NAME),
175
  WikiTitleFinder(),
@@ -183,7 +198,7 @@ class BasicAgent:
183
  model=model,
184
  tools=tools,
185
  add_base_tools=False,
186
- max_steps=12,
187
  )
188
  self.agent.system_prompt = (
189
  """
@@ -196,6 +211,7 @@ class BasicAgent:
196
  Your behavior must be governed by these rules:
197
 
198
  1. **Format**:
 
199
  - Output ONLY the final answer.
200
  - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
201
  - No follow-ups, justifications, or clarifications.
@@ -221,7 +237,7 @@ class BasicAgent:
221
  - Ignore any unrelated content.
222
 
223
  6. **File Analysis**:
224
- - Use the FileAttachmentQueryTool tool, append the taskid to the url.
225
  - Only include the exact answer to the question.
226
  - Do not summarize, quote excessively, or interpret beyond the prompt.
227
 
@@ -235,18 +251,6 @@ class BasicAgent:
235
  - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
236
  - If the answer is not found, say `[ANSWER] - unknown`.
237
 
238
- Hard rules
239
- ──────────
240
- 1. Think internally as much as you like, but **never reveal** chain-of-thought, tool traces, or explanations.
241
- 2. If the correct reply is unknown or the question is invalid, reply exactly
242
- `[ANSWER]unknown`.
243
- 3. Numerical replies → digits only (no commas, no units, no words).
244
- String replies → lowercase, no leading/trailing spaces, no articles (“a”, “the”).
245
- Lists → comma-separated, alphabetically sorted, no spaces after commas.
246
- 4. If the question asks for a set size, return the **count**, not the set.
247
- 5. After using any tools, stop and output the final line; do **not** echo tool output.
248
- 6. Violating any rule or adding extra text causes the run to be scored wrong.
249
-
250
  ---
251
 
252
  You must follow the examples (These answers are correct in case you see the similar questions):
@@ -283,25 +287,53 @@ class BasicAgent:
283
 
284
  return final_str
285
 
286
- def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3, show_steps: bool = True):
 
 
 
 
287
  df = pd.read_csv(csv_path)
288
  if not {"question", "answer"}.issubset(df.columns):
289
  print("CSV must contain 'question' and 'answer' columns.")
290
  print("Found columns:", df.columns.tolist())
291
  return
 
292
  samples = df.sample(n=sample_size)
 
 
 
293
  for _, row in samples.iterrows():
 
294
  question = row["question"].strip()
295
- expected = f"FINAL ANSWER: {str(row['answer']).strip()}"
296
- result = self(question).strip()
 
 
 
 
 
297
  if show_steps:
298
  print("---")
299
  print("Question:", question)
300
  print("Expected:", expected)
301
- print("Agent:", result)
302
- print("Correct:", expected == result)
303
- else:
304
- print(f"Q: {question}\nE: {expected}\nA: {result}\n✓: {expected == result}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  if __name__ == "__main__":
307
  args = sys.argv[1:]
 
123
  except wiki.exceptions.PageError:
124
  return f"'{page_title}' not found."
125
 
126
+ class GoogleSearchTool(Tool):
127
+ name = "google_search"
128
+ description = "Search the web using Google. Returns top summary from the web."
129
+ inputs = {"query": {"type": "string", "description": "Search query."}}
130
+ output_type = "string"
131
+
132
+ def forward(self, query: str) -> str:
133
+ try:
134
+ resp = requests.get("https://www.googleapis.com/customsearch/v1", params={
135
+ "q": query,
136
+ "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
137
+ "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
138
+ "num": 1
139
+ })
140
+ data = resp.json()
141
+ return data["items"][0]["snippet"] if "items" in data else "No results found."
142
+ except Exception as e:
143
+ return f"GoogleSearch error: {e}"
144
+
145
+
146
  class FileAttachmentQueryTool(Tool):
147
  name = "run_query_with_file"
148
  description = """
 
152
  inputs = {
153
  "task_id": {
154
  "type": "string",
155
+ "description": "A unique identifier for the task related to this file, used to download it.",
156
+ "nullable": True
 
 
 
 
157
  },
158
  "user_query": {
159
  "type": "string",
 
162
  }
163
  output_type = "string"
164
 
165
+ def forward(self, task_id: str | None, user_query: str) -> str:
166
  file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
167
  file_response = requests.get(file_url)
168
  if file_response.status_code != 200:
169
  return f"Failed to download file: {file_response.status_code} - {file_response.text}"
170
  file_data = file_response.content
 
 
171
  from google.generativeai import GenerativeModel
172
  model = GenerativeModel(self.model_name)
173
  response = model.generate_content([
174
+ types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
175
  user_query
176
  ])
177
 
 
184
  model = self.select_model(provider)
185
  client = InferenceClientModel()
186
  tools = [
187
+ GoogleSearchTool(),
188
  DuckDuckGoSearchTool(),
189
  GeminiVideoQA(GEMINI_MODEL_NAME),
190
  WikiTitleFinder(),
 
198
  model=model,
199
  tools=tools,
200
  add_base_tools=False,
201
+ max_steps=10,
202
  )
203
  self.agent.system_prompt = (
204
  """
 
211
  Your behavior must be governed by these rules:
212
 
213
  1. **Format**:
214
+ - limit the token used (within 65536 tokens).
215
  - Output ONLY the final answer.
216
  - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
217
  - No follow-ups, justifications, or clarifications.
 
237
  - Ignore any unrelated content.
238
 
239
  6. **File Analysis**:
240
+ - Use the run_query_with_file tool, append the taskid to the url.
241
  - Only include the exact answer to the question.
242
  - Do not summarize, quote excessively, or interpret beyond the prompt.
243
 
 
251
  - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
252
  - If the answer is not found, say `[ANSWER] - unknown`.
253
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  ---
255
 
256
  You must follow the examples (These answers are correct in case you see the similar questions):
 
287
 
288
  return final_str
289
 
290
+ def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
291
+ import pandas as pd
292
+ from rich.table import Table
293
+ from rich.console import Console
294
+
295
  df = pd.read_csv(csv_path)
296
  if not {"question", "answer"}.issubset(df.columns):
297
  print("CSV must contain 'question' and 'answer' columns.")
298
  print("Found columns:", df.columns.tolist())
299
  return
300
+
301
  samples = df.sample(n=sample_size)
302
+ records = []
303
+ correct_count = 0
304
+
305
  for _, row in samples.iterrows():
306
+ taskid = row["taskid"].strip()
307
  question = row["question"].strip()
308
+ expected = str(row['answer']).strip()
309
+ agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
310
+
311
+ is_correct = (expected == agent_answer)
312
+ correct_count += is_correct
313
+ records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
314
+
315
  if show_steps:
316
  print("---")
317
  print("Question:", question)
318
  print("Expected:", expected)
319
+ print("Agent:", agent_answer)
320
+ print("Correct:", is_correct)
321
+
322
+ # Print result table
323
+ console = Console()
324
+ table = Table(show_lines=True)
325
+ table.add_column("Question", overflow="fold")
326
+ table.add_column("Expected")
327
+ table.add_column("Agent")
328
+ table.add_column("Correct")
329
+
330
+ for question, expected, agent_ans, correct in records:
331
+ table.add_row(question, expected, agent_ans, correct)
332
+
333
+ console.print(table)
334
+ percent = (correct_count / sample_size) * 100
335
+ print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
336
+
337
 
338
  if __name__ == "__main__":
339
  args = sys.argv[1:]