Omachoko committed
Commit 2d0e062 · 1 Parent(s): 50f18bd

GAIA agent: strict output normalization, reasoning planner, RAG, modular tool chaining, robust error handling

Files changed (1)
app.py  +85 -53
app.py CHANGED
@@ -4,6 +4,7 @@ import requests
 import inspect
 import pandas as pd
 from typing import Any
+import re
 
 # (Keep Constants as is)
 # --- Constants ---
@@ -281,13 +282,81 @@ Question:
 Answer:
 """
 
+# --- Centralized Output Formatting & Normalization ---
+def gaia_normalize_answer(answer):
+    """Normalize answer for GAIA: remove units, articles, extra text, and ensure concise, factual output."""
+    if not isinstance(answer, str):
+        answer = str(answer)
+    # Remove common articles and units unless required
+    answer = answer.strip()
+    answer = re.sub(r"\b(the|a|an)\b", "", answer, flags=re.IGNORECASE)
+    answer = re.sub(r"\s+", " ", answer)
+    # Remove currency, percent, or units unless specified (GAIA rules)
+    answer = re.sub(r"\$|%|USD|dollars|euros|eur|\bpercent\b", "", answer, flags=re.IGNORECASE)
+    # Remove leading/trailing punctuation
+    answer = answer.strip(' .,:;\n\t')
+    return answer
+
+# --- Reasoning Planner for Tool Chaining ---
+def reasoning_planner(question, file_type, tools):
+    """Plan the sequence of tools to use for a question. Uses LLM or heuristic."""
+    # Heuristic: if file_type is known, use the corresponding tool; else, use web search + LLM
+    if file_type == 'audio':
+        return ['asr_transcribe', 'llama3_chat']
+    elif file_type == 'image':
+        return ['image_caption', 'llama3_chat']
+    elif file_type == 'code':
+        return ['code_analysis', 'llama3_chat']
+    elif file_type in ['excel', 'csv']:
+        return ['table_qa']
+    elif 'youtube.com' in question or 'youtu.be' in question:
+        return ['youtube_video_qa']
+    elif any(w in question.lower() for w in ['wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search']):
+        return ['web_search_duckduckgo', 'llama3_chat']
+    else:
+        return ['llama3_chat']
+
+# --- Improved RAG: Context Retrieval & Chunking ---
+def retrieve_context(question, context_files, max_chunks=3):
+    """Retrieve relevant context chunks from large files for RAG."""
+    # Simple keyword search for now; can be replaced with semantic search
+    relevant_chunks = []
+    for file_path in context_files:
+        try:
+            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+                text = f.read()
+            # Split into chunks (e.g., 500 words)
+            chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
+            for chunk in chunks:
+                if any(word.lower() in chunk.lower() for word in question.split()):
+                    relevant_chunks.append(chunk)
+                if len(relevant_chunks) >= max_chunks:
+                    break
+        except Exception as e:
+            logger.error(f"retrieve_context error: {e}")
+    return '\n'.join(relevant_chunks)
+
+# --- Modular Tool Registry & Chaining ---
+class ToolRegistry:
+    """Central registry for tools. Allows easy addition and chaining."""
+    def __init__(self, tools):
+        self.tools = tools
+    def get(self, name):
+        return self.tools.get(name)
+    def add(self, name, func):
+        self.tools[name] = func
+    def list(self):
+        return list(self.tools.keys())
+
 # --- Refactored ModularGAIAAgent ---
 class ModularGAIAAgent:
-    def __init__(self, api_url=DEFAULT_API_URL, tool_registry=None):
+    """GAIA-compliant agent with robust reasoning, tool chaining, RAG, and output normalization."""
+    def __init__(self, api_url=DEFAULT_API_URL, tool_registry=None, context_files=None):
         self.api_url = api_url
-        self.tools = tool_registry or TOOL_REGISTRY
+        self.tools = ToolRegistry(tool_registry or TOOL_REGISTRY)
         self.reasoning_trace = []
         self.file_cache = set(os.listdir('.'))
+        self.context_files = context_files or []
 
     def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
         """Fetch questions from API or local file."""
@@ -357,15 +426,15 @@ class ModularGAIAAgent:
         """Analyze file and return context for the question."""
         try:
             if file_type == 'audio':
-                transcript = self.tools['asr_transcribe'](file_name)
+                transcript = self.tools.get('asr_transcribe')(file_name)
                 self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
                 return transcript
             elif file_type == 'image':
-                caption = self.tools['image_caption'](file_name)
+                caption = self.tools.get('image_caption')(file_name)
                 self.reasoning_trace.append(f"Image caption: {caption}")
                 return caption
             elif file_type == 'code':
-                result = self.tools['code_analysis'](file_name)
+                result = self.tools.get('code_analysis')(file_name)
                 self.reasoning_trace.append(f"Code analysis result: {result}")
                 return result
             elif file_type == 'excel':
@@ -400,41 +469,7 @@
             self.reasoning_trace.append(f"Analyze file error: {e}")
             return None
 
-    def smart_tool_select(self, question, file_type=None):
-        """Select the best tool(s) for the question, optionally using GPT-4.1 for planning."""
-        api_key = os.environ.get("OPENAI_API_KEY", "")
-        try:
-            if api_key:
-                plan_prompt = f"""
-You are an expert AI agent. Given the following question and file type, suggest the best tool(s) to use from this list: {list(self.tools.keys())}.
-Question: {question}
-File type: {file_type}
-Respond with a comma-separated list of tool names only, in order of use. If unsure, start with web_search_duckduckgo.
-"""
-                plan = gpt4_chat(plan_prompt, api_key=api_key)
-                tool_names = [t.strip() for t in plan.split(',') if t.strip() in self.tools]
-                if tool_names:
-                    return tool_names
-        except Exception as e:
-            logger.error(f"smart_tool_select planning error: {e}")
-        # Fallback: heuristic
-        if file_type == 'audio':
-            return ['asr_transcribe']
-        elif file_type == 'image':
-            return ['image_caption']
-        elif file_type == 'code':
-            return ['code_analysis']
-        elif file_type in ['excel', 'csv']:
-            return ['table_qa']
-        elif 'youtube.com' in question or 'youtu.be' in question:
-            return ['youtube_video_qa']
-        elif any(w in question.lower() for w in ['wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search']):
-            return ['web_search_duckduckgo']
-        else:
-            return ['llama3_chat']
-
     def answer_question(self, question_obj):
-        """Answer a question using the best tool(s) and context."""
         self.reasoning_trace = []
         q = question_obj["question"]
         file_name = question_obj.get("file_name", "")
@@ -446,19 +481,23 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
         if local_file:
             file_type = self.detect_file_type(local_file)
             file_content = self.analyze_file(local_file, file_type)
-        # Smart tool selection
-        tool_names = self.smart_tool_select(q, file_type)
+        # RAG: retrieve context if needed
+        rag_context = ''
+        if not file_content and self.context_files:
+            rag_context = retrieve_context(q, self.context_files)
+            if rag_context:
+                self.reasoning_trace.append(f"RAG context used: {rag_context[:200]}...")
+        # Reasoning planner: decide tool chain
+        tool_names = reasoning_planner(q, file_type, self.tools.list())
         answer = None
-        context = file_content
+        context = file_content or rag_context
         for tool_name in tool_names:
-            tool = self.tools[tool_name]
+            tool = self.tools.get(tool_name)
             try:
                 logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
                 if tool_name == 'web_search_duckduckgo':
                     context = tool(q)
                     answer = llama3_chat(build_prompt(context, q))
-                elif tool_name == 'gpt4_chat':
-                    answer = tool(build_prompt(context, q))
                 elif tool_name == 'table_qa' and file_content:
                     answer = tool(q, file_content)
                 elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
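A small sketch of the RAG fallback wired in above: retrieve_context scans plain-text files in 2000-character chunks and keeps chunks that share keywords with the question, which answer_question then feeds into the prompt when no file content is available. The file name and contents below are invented for illustration, and the import again assumes app.py loads cleanly.

# Illustrative sketch of the keyword-based RAG retrieval used as a fallback context source.
import pathlib
import tempfile

from app import retrieve_context

with tempfile.TemporaryDirectory() as tmp:
    notes = pathlib.Path(tmp) / "notes.txt"                     # hypothetical context file
    notes.write_text("The Eiffel Tower was completed in 1889 and stands in Paris.")
    # Keyword overlap with the question ("was", "Eiffel", "Tower", ...) selects the chunk.
    print(retrieve_context("When was the Eiffel Tower completed?", [notes]))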
@@ -466,7 +505,6 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
                 elif tool_name == 'youtube_video_qa':
                     answer = tool(q, q)
                 else:
-                    # Always pass context if available
                     if context:
                         answer = llama3_chat(build_prompt(context, q))
                     else:
@@ -479,13 +517,7 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
                 continue
         self.reasoning_trace.append(f"Tools used: {tool_names}")
         self.reasoning_trace.append(f"Final answer: {answer}")
-        return self.format_answer(answer), self.reasoning_trace
-
-    def format_answer(self, answer):
-        """Strict GAIA: only the answer, no extra text, no prefix."""
-        if isinstance(answer, str):
-            return answer.strip().split('\n')[0]
-        return str(answer)
+        return gaia_normalize_answer(answer), self.reasoning_trace
 
 # --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
 class BasicAgent:
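Finally, an end-to-end sketch of how the refactored agent would be driven. The question dict mirrors the fields answer_question reads (question, file_name); task_id is included only because the GAIA API supplies it, and the run assumes TOOL_REGISTRY, the scoring API, and the LLM backends are reachable. Illustrative only, not part of the commit.

# Illustrative sketch; a real run needs the GAIA API, TOOL_REGISTRY, and the LLM backends.
from app import ModularGAIAAgent

agent = ModularGAIAAgent(context_files=["notes.txt"])   # optional corpus for the RAG fallback
question = {
    "task_id": "demo-1",                                # supplied by the GAIA API; not read in the hunks above
    "question": "What is the capital of France?",
    "file_name": "",
}
answer, trace = agent.answer_question(question)
print(answer)               # already passed through gaia_normalize_answer, e.g. "Paris"
print("\n".join(trace))     # reasoning trace: tools used, intermediate context, final answer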