Omachoko committed
Commit 008f512 · 1 Parent(s): 10285e9

Robustify agent: better context passing, error handling, logging, prompt engineering, and dependencies

Files changed (2)
  1. app.py +181 -107
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,4 +1,4 @@
- import os
+ chess - screenshote - screenshote - screenshote - screenshote - screenshotimport os
  import gradio as gr
  import requests
  import inspect
@@ -25,6 +25,7 @@ import cv2
  import torch
  from bs4 import BeautifulSoup
  import openai
+ import magic # for robust file type detection

  logging.basicConfig(filename='gaia_agent.log', level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')
  logger = logging.getLogger(__name__)
@@ -38,7 +39,7 @@ def llama3_chat(prompt):
  messages=[{"role": "user", "content": prompt}],
  )
  return completion.choices[0].message.content
- except Exception as e:
+ except Exception as e:
  logging.error(f"llama3_chat error: {e}")
  return f"LLM error: {e}"

@@ -243,116 +244,179 @@ TOOL_REGISTRY = {
  "gpt4_chat": gpt4_chat,
  }

+ # --- Utility: Robust file type detection ---
+ def detect_file_type_magic(file_name):
+ try:
+ mime = magic.Magic(mime=True)
+ filetype = mime.from_file(file_name)
+ if 'audio' in filetype:
+ return 'audio'
+ elif 'image' in filetype:
+ return 'image'
+ elif 'python' in filetype or file_name.endswith('.py'):
+ return 'code'
+ elif 'spreadsheet' in filetype or file_name.endswith('.xlsx'):
+ return 'excel'
+ elif 'csv' in filetype or file_name.endswith('.csv'):
+ return 'csv'
+ elif 'json' in filetype or file_name.endswith('.json'):
+ return 'json'
+ elif 'text' in filetype or file_name.endswith(('.txt', '.md')):
+ return 'text'
+ else:
+ return 'unknown'
+ except Exception as e:
+ logger.error(f"magic file type detection error: {e}")
+ return 'unknown'
+
+ # --- Improved prompt template for LLMs ---
+ def build_prompt(context, question):
+ return f"""
+ Context:
+ {context}
+
+ Question:
+ {question}
+
+ Answer:
+ """
+
+ # --- Refactored ModularGAIAAgent ---
  class ModularGAIAAgent:
- def __init__(self, api_url=DEFAULT_API_URL, tool_registry=TOOL_REGISTRY):
+ def __init__(self, api_url=DEFAULT_API_URL, tool_registry=None):
  self.api_url = api_url
- self.tools = tool_registry
+ self.tools = tool_registry or TOOL_REGISTRY
  self.reasoning_trace = []
  self.file_cache = set(os.listdir('.'))

  def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
- if from_api:
- r = requests.get(f"{self.api_url}/questions")
- r.raise_for_status()
- return r.json()
- else:
- with open(questions_path) as f:
- data = f.read()
- start = data.find("[")
- end = data.rfind("]") + 1
- questions = json.loads(data[start:end])
- return questions
+ """Fetch questions from API or local file."""
+ try:
+ if from_api:
+ r = requests.get(f"{self.api_url}/questions")
+ r.raise_for_status()
+ return r.json()
+ else:
+ with open(questions_path) as f:
+ data = f.read()
+ start = data.find("[")
+ end = data.rfind("]") + 1
+ questions = json.loads(data[start:end])
+ return questions
+ except Exception as e:
+ logger.error(f"fetch_questions error: {e}")
+ return []

  def download_file(self, file_id, file_name=None):
- if not file_name:
- file_name = file_id
- if file_name in self.file_cache:
- return file_name
- url = f"{self.api_url}/files/{file_id}"
- r = requests.get(url)
- if r.status_code == 200:
- with open(file_name, "wb") as f:
- f.write(r.content)
- self.file_cache.add(file_name)
- return file_name
- else:
- self.reasoning_trace.append(f"Failed to download file {file_id} (status {r.status_code})")
+ """Download file if not present locally."""
+ try:
+ if not file_name:
+ file_name = file_id
+ if file_name in self.file_cache:
+ return file_name
+ url = f"{self.api_url}/files/{file_id}"
+ r = requests.get(url)
+ if r.status_code == 200:
+ with open(file_name, "wb") as f:
+ f.write(r.content)
+ self.file_cache.add(file_name)
+ return file_name
+ else:
+ self.reasoning_trace.append(f"Failed to download file {file_id} (status {r.status_code})")
+ logger.error(f"Failed to download file {file_id} (status {r.status_code})")
+ return None
+ except Exception as e:
+ logger.error(f"download_file error: {e}")
+ self.reasoning_trace.append(f"Download error: {e}")
  return None

  def detect_file_type(self, file_name):
- ext = os.path.splitext(file_name)[-1].lower()
- if ext in ['.mp3', '.wav', '.flac']:
- return 'audio'
- elif ext in ['.png', '.jpg', '.jpeg', '.bmp']:
- return 'image'
- elif ext in ['.py']:
- return 'code'
- elif ext in ['.xlsx']:
- return 'excel'
- elif ext in ['.csv']:
- return 'csv'
- elif ext in ['.json']:
- return 'json'
- elif ext in ['.txt', '.md']:
- return 'text'
- else:
- return 'unknown'
+ """Detect file type using magic and extension as fallback."""
+ file_type = detect_file_type_magic(file_name)
+ if file_type == 'unknown':
+ ext = os.path.splitext(file_name)[-1].lower()
+ if ext in ['.mp3', '.wav', '.flac']:
+ return 'audio'
+ elif ext in ['.png', '.jpg', '.jpeg', '.bmp']:
+ return 'image'
+ elif ext in ['.py']:
+ return 'code'
+ elif ext in ['.xlsx']:
+ return 'excel'
+ elif ext in ['.csv']:
+ return 'csv'
+ elif ext in ['.json']:
+ return 'json'
+ elif ext in ['.txt', '.md']:
+ return 'text'
+ else:
+ return 'unknown'
+ return file_type

  def analyze_file(self, file_name, file_type):
- if file_type == 'audio':
- transcript = self.tools['asr_transcribe'](file_name)
- self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
- return transcript
- elif file_type == 'image':
- caption = self.tools['image_caption'](file_name)
- self.reasoning_trace.append(f"Image caption: {caption}")
- return caption
- elif file_type == 'code':
- result = self.tools['code_analysis'](file_name)
- self.reasoning_trace.append(f"Code analysis result: {result}")
- return result
- elif file_type == 'excel':
- wb = openpyxl.load_workbook(file_name)
- ws = wb.active
- data = list(ws.values)
- headers = data[0]
- table = [dict(zip(headers, row)) for row in data[1:]]
- self.reasoning_trace.append(f"Excel table loaded: {table[:2]}...")
- return table
- elif file_type == 'csv':
- df = pd.read_csv(file_name)
- table = df.to_dict(orient='records')
- self.reasoning_trace.append(f"CSV table loaded: {table[:2]}...")
- return table
- elif file_type == 'json':
- with open(file_name) as f:
- data = json.load(f)
- self.reasoning_trace.append(f"JSON loaded: {str(data)[:100]}...")
- return data
- elif file_type == 'text':
- with open(file_name) as f:
- text = f.read()
- self.reasoning_trace.append(f"Text loaded: {text[:100]}...")
- return text
- else:
- self.reasoning_trace.append(f"Unknown file type: {file_name}")
+ """Analyze file and return context for the question."""
+ try:
+ if file_type == 'audio':
+ transcript = self.tools['asr_transcribe'](file_name)
+ self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
+ return transcript
+ elif file_type == 'image':
+ caption = self.tools['image_caption'](file_name)
+ self.reasoning_trace.append(f"Image caption: {caption}")
+ return caption
+ elif file_type == 'code':
+ result = self.tools['code_analysis'](file_name)
+ self.reasoning_trace.append(f"Code analysis result: {result}")
+ return result
+ elif file_type == 'excel':
+ wb = openpyxl.load_workbook(file_name)
+ ws = wb.active
+ data = list(ws.values)
+ headers = data[0]
+ table = [dict(zip(headers, row)) for row in data[1:]]
+ self.reasoning_trace.append(f"Excel table loaded: {table[:2]}...")
+ return table
+ elif file_type == 'csv':
+ df = pd.read_csv(file_name)
+ table = df.to_dict(orient='records')
+ self.reasoning_trace.append(f"CSV table loaded: {table[:2]}...")
+ return table
+ elif file_type == 'json':
+ with open(file_name) as f:
+ data = json.load(f)
+ self.reasoning_trace.append(f"JSON loaded: {str(data)[:100]}...")
+ return data
+ elif file_type == 'text':
+ with open(file_name) as f:
+ text = f.read()
+ self.reasoning_trace.append(f"Text loaded: {text[:100]}...")
+ return text
+ else:
+ self.reasoning_trace.append(f"Unknown file type: {file_name}")
+ logger.warning(f"Unknown file type: {file_name}")
+ return None
+ except Exception as e:
+ logger.error(f"analyze_file error: {e}")
+ self.reasoning_trace.append(f"Analyze file error: {e}")
  return None

  def smart_tool_select(self, question, file_type=None):
  """Select the best tool(s) for the question, optionally using GPT-4.1 for planning."""
- # Use GPT-4.1 to suggest a tool if available
  api_key = os.environ.get("OPENAI_API_KEY", "")
- if api_key:
- plan_prompt = f"""
+ try:
+ if api_key:
+ plan_prompt = f"""
  You are an expert AI agent. Given the following question and file type, suggest the best tool(s) to use from this list: {list(self.tools.keys())}.
  Question: {question}
  File type: {file_type}
  Respond with a comma-separated list of tool names only, in order of use. If unsure, start with web_search_duckduckgo.
  """
- plan = gpt4_chat(plan_prompt, api_key=api_key)
- tool_names = [t.strip() for t in plan.split(',') if t.strip() in self.tools]
- if tool_names:
- return tool_names
+ plan = gpt4_chat(plan_prompt, api_key=api_key)
+ tool_names = [t.strip() for t in plan.split(',') if t.strip() in self.tools]
+ if tool_names:
+ return tool_names
+ except Exception as e:
+ logger.error(f"smart_tool_select planning error: {e}")
  # Fallback: heuristic
  if file_type == 'audio':
  return ['asr_transcribe']
@@ -370,6 +434,7 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
  return ['llama3_chat']

  def answer_question(self, question_obj):
+ """Answer a question using the best tool(s) and context."""
  self.reasoning_trace = []
  q = question_obj["question"]
  file_name = question_obj.get("file_name", "")
@@ -384,31 +449,40 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
  # Smart tool selection
  tool_names = self.smart_tool_select(q, file_type)
  answer = None
- context = None
+ context = file_content
  for tool_name in tool_names:
  tool = self.tools[tool_name]
- if tool_name == 'web_search_duckduckgo':
- context = tool(q)
- # Use LLM to synthesize answer from snippets
- answer = llama3_chat(f"Answer the following question using ONLY the information below.\nQuestion: {q}\nSnippets:\n{context}\nAnswer:")
- elif tool_name == 'gpt4_chat':
- answer = tool(q)
- elif tool_name == 'table_qa' and file_content:
- answer = tool(q, file_content)
- elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
- answer = tool(file_name)
- elif tool_name == 'youtube_video_qa':
- answer = tool(q, q)
- else:
- answer = tool(q)
- if answer:
- break
+ try:
+ logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
+ if tool_name == 'web_search_duckduckgo':
+ context = tool(q)
+ answer = llama3_chat(build_prompt(context, q))
+ elif tool_name == 'gpt4_chat':
+ answer = tool(build_prompt(context, q))
+ elif tool_name == 'table_qa' and file_content:
+ answer = tool(q, file_content)
+ elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
+ answer = tool(file_name)
+ elif tool_name == 'youtube_video_qa':
+ answer = tool(q, q)
+ else:
+ # Always pass context if available
+ if context:
+ answer = llama3_chat(build_prompt(context, q))
+ else:
+ answer = tool(q)
+ if answer:
+ break
+ except Exception as e:
+ logger.error(f"Tool {tool_name} error: {e}")
+ self.reasoning_trace.append(f"Tool {tool_name} error: {e}")
+ continue
  self.reasoning_trace.append(f"Tools used: {tool_names}")
  self.reasoning_trace.append(f"Final answer: {answer}")
  return self.format_answer(answer), self.reasoning_trace

  def format_answer(self, answer):
- # Strict GAIA: only the answer, no extra text, no prefix
+ """Strict GAIA: only the answer, no extra text, no prefix."""
  if isinstance(answer, str):
  return answer.strip().split('\n')[0]
  return str(answer)
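For reference, a minimal usage sketch of the refactored agent follows. It assumes app.py is importable as a module, that DEFAULT_API_URL points at the GAIA scoring API, and that each question object carries the "question" and optional "file_name" fields used in the diff above; it is a sketch, not part of the commit.

# Minimal usage sketch (assumes app.py and its dependencies are installed and importable).
from app import ModularGAIAAgent

agent = ModularGAIAAgent()  # defaults: DEFAULT_API_URL and the module-level TOOL_REGISTRY

questions = agent.fetch_questions(from_api=True)  # returns [] on error per the new error handling
for question_obj in questions[:1]:
    # answer_question returns (formatted_answer, reasoning_trace)
    answer, trace = agent.answer_question(question_obj)
    print("Answer:", answer)
    for step in trace:
        print(" -", step)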
requirements.txt CHANGED
@@ -11,4 +11,7 @@ opencv-python
  beautifulsoup4
  yt-dlp
  ultralytics
- openai
+ openai
+ torchaudio
+ ffmpeg-python
+ python-magic
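One note on the new dependencies: python-magic wraps the system libmagic library, which must be available in the runtime image for magic-based detection to work (otherwise detect_file_type falls back to extension checks), and the pip package names differ from the import names. A small, hedged sanity check of the imports:

# Sanity check that the newly added requirements import cleanly.
# Note: ffmpeg-python imports as "ffmpeg" and python-magic as "magic".
import importlib

for module_name in ("torchaudio", "ffmpeg", "magic"):
    try:
        importlib.import_module(module_name)
        print(f"{module_name}: OK")
    except ImportError as exc:  # python-magic raises ImportError if system libmagic is missing
        print(f"{module_name}: missing ({exc})")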