Spaces:
Runtime error
Runtime error
Omachoko
committed on
Commit
·
2d0e062
1
Parent(s):
50f18bd
GAIA agent: strict output normalization, reasoning planner, RAG, modular tool chaining, robust error handling
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import requests
|
|
4 |
import inspect
|
5 |
import pandas as pd
|
6 |
from typing import Any
|
|
|
7 |
|
8 |
# (Keep Constants as is)
|
9 |
# --- Constants ---
|
@@ -281,13 +282,81 @@ Question:
|
|
281 |
Answer:
|
282 |
"""
|
283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
# --- Refactored ModularGAIAAgent ---
|
285 |
class ModularGAIAAgent:
|
286 |
-
|
|
|
287 |
self.api_url = api_url
|
288 |
-
self.tools = tool_registry or TOOL_REGISTRY
|
289 |
self.reasoning_trace = []
|
290 |
self.file_cache = set(os.listdir('.'))
|
|
|
291 |
|
292 |
def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
|
293 |
"""Fetch questions from API or local file."""
|
@@ -357,15 +426,15 @@ class ModularGAIAAgent:
|
|
357 |
"""Analyze file and return context for the question."""
|
358 |
try:
|
359 |
if file_type == 'audio':
|
360 |
-
transcript = self.tools
|
361 |
self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
|
362 |
return transcript
|
363 |
elif file_type == 'image':
|
364 |
-
caption = self.tools
|
365 |
self.reasoning_trace.append(f"Image caption: {caption}")
|
366 |
return caption
|
367 |
elif file_type == 'code':
|
368 |
-
result = self.tools
|
369 |
self.reasoning_trace.append(f"Code analysis result: {result}")
|
370 |
return result
|
371 |
elif file_type == 'excel':
|
@@ -400,41 +469,7 @@ class ModularGAIAAgent:
|
|
400 |
self.reasoning_trace.append(f"Analyze file error: {e}")
|
401 |
return None
|
402 |
|
403 |
-
def smart_tool_select(self, question, file_type=None):
|
404 |
-
"""Select the best tool(s) for the question, optionally using GPT-4.1 for planning."""
|
405 |
-
api_key = os.environ.get("OPENAI_API_KEY", "")
|
406 |
-
try:
|
407 |
-
if api_key:
|
408 |
-
plan_prompt = f"""
|
409 |
-
You are an expert AI agent. Given the following question and file type, suggest the best tool(s) to use from this list: {list(self.tools.keys())}.
|
410 |
-
Question: {question}
|
411 |
-
File type: {file_type}
|
412 |
-
Respond with a comma-separated list of tool names only, in order of use. If unsure, start with web_search_duckduckgo.
|
413 |
-
"""
|
414 |
-
plan = gpt4_chat(plan_prompt, api_key=api_key)
|
415 |
-
tool_names = [t.strip() for t in plan.split(',') if t.strip() in self.tools]
|
416 |
-
if tool_names:
|
417 |
-
return tool_names
|
418 |
-
except Exception as e:
|
419 |
-
logger.error(f"smart_tool_select planning error: {e}")
|
420 |
-
# Fallback: heuristic
|
421 |
-
if file_type == 'audio':
|
422 |
-
return ['asr_transcribe']
|
423 |
-
elif file_type == 'image':
|
424 |
-
return ['image_caption']
|
425 |
-
elif file_type == 'code':
|
426 |
-
return ['code_analysis']
|
427 |
-
elif file_type in ['excel', 'csv']:
|
428 |
-
return ['table_qa']
|
429 |
-
elif 'youtube.com' in question or 'youtu.be' in question:
|
430 |
-
return ['youtube_video_qa']
|
431 |
-
elif any(w in question.lower() for w in ['wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search']):
|
432 |
-
return ['web_search_duckduckgo']
|
433 |
-
else:
|
434 |
-
return ['llama3_chat']
|
435 |
-
|
436 |
def answer_question(self, question_obj):
|
437 |
-
"""Answer a question using the best tool(s) and context."""
|
438 |
self.reasoning_trace = []
|
439 |
q = question_obj["question"]
|
440 |
file_name = question_obj.get("file_name", "")
|
@@ -446,19 +481,23 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
446 |
if local_file:
|
447 |
file_type = self.detect_file_type(local_file)
|
448 |
file_content = self.analyze_file(local_file, file_type)
|
449 |
-
#
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
answer = None
|
452 |
-
context = file_content
|
453 |
for tool_name in tool_names:
|
454 |
-
tool = self.tools
|
455 |
try:
|
456 |
logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
|
457 |
if tool_name == 'web_search_duckduckgo':
|
458 |
context = tool(q)
|
459 |
answer = llama3_chat(build_prompt(context, q))
|
460 |
-
elif tool_name == 'gpt4_chat':
|
461 |
-
answer = tool(build_prompt(context, q))
|
462 |
elif tool_name == 'table_qa' and file_content:
|
463 |
answer = tool(q, file_content)
|
464 |
elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
|
@@ -466,7 +505,6 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
466 |
elif tool_name == 'youtube_video_qa':
|
467 |
answer = tool(q, q)
|
468 |
else:
|
469 |
-
# Always pass context if available
|
470 |
if context:
|
471 |
answer = llama3_chat(build_prompt(context, q))
|
472 |
else:
|
@@ -479,13 +517,7 @@ Respond with a comma-separated list of tool names only, in order of use. If unsu
|
|
479 |
continue
|
480 |
self.reasoning_trace.append(f"Tools used: {tool_names}")
|
481 |
self.reasoning_trace.append(f"Final answer: {answer}")
|
482 |
-
return
|
483 |
-
|
484 |
-
def format_answer(self, answer):
|
485 |
-
"""Strict GAIA: only the answer, no extra text, no prefix."""
|
486 |
-
if isinstance(answer, str):
|
487 |
-
return answer.strip().split('\n')[0]
|
488 |
-
return str(answer)
|
489 |
|
490 |
# --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
|
491 |
class BasicAgent:
|
|
|
4 |
import inspect
|
5 |
import pandas as pd
|
6 |
from typing import Any
|
7 |
+
import re
|
8 |
|
9 |
# (Keep Constants as is)
|
10 |
# --- Constants ---
|
|
|
282 |
Answer:
|
283 |
"""
|
284 |
|
285 |
+
# --- Centralized Output Formatting & Normalization ---
def gaia_normalize_answer(answer):
    """Normalize an answer for strict GAIA scoring.

    Coerces the value to ``str``, drops standalone English articles,
    strips currency/percent markers and unit words, collapses runs of
    whitespace, and trims surrounding punctuation.

    Args:
        answer: Raw answer of any type; non-strings are ``str()``-ed.

    Returns:
        str: The normalized, concise answer.
    """
    if not isinstance(answer, str):
        answer = str(answer)
    answer = answer.strip()
    # GAIA answers omit articles; \b keeps "a"/"an"/"the" inside words intact.
    answer = re.sub(r"\b(the|a|an)\b", "", answer, flags=re.IGNORECASE)
    answer = re.sub(r"\s+", " ", answer)
    # Strip currency/percent symbols and unit words (GAIA rules). The word
    # alternatives need \b boundaries: a bare "eur" with IGNORECASE would
    # otherwise mangle words like "Europe" -> "ope".
    answer = re.sub(r"\$|%|\b(?:USD|dollars|euros|eur|percent)\b", "", answer,
                    flags=re.IGNORECASE)
    # Remove leading/trailing punctuation left over from the substitutions.
    return answer.strip(' .,:;\n\t')
|
299 |
+
|
300 |
+
# --- Reasoning Planner for Tool Chaining ---
def reasoning_planner(question, file_type, tools):
    """Plan the sequence of tools to use for a question. Uses LLM or heuristic.

    Heuristic: a known attachment type picks its dedicated tool chain;
    otherwise the question text is inspected for a YouTube link or
    open-domain search cues, falling back to a plain LLM chat.

    Args:
        question: The question text (used for URL/keyword heuristics).
        file_type: Detected attachment type ('audio', 'image', 'code',
            'excel', 'csv') or None.
        tools: Available tool names (not consulted by the heuristic yet).

    Returns:
        list[str]: Tool names in the order they should be invoked.
    """
    # Dispatch table: attachment type -> tool chain.
    chain_by_type = {
        'audio': ['asr_transcribe', 'llama3_chat'],
        'image': ['image_caption', 'llama3_chat'],
        'code': ['code_analysis', 'llama3_chat'],
        'excel': ['table_qa'],
        'csv': ['table_qa'],
    }
    if file_type in chain_by_type:
        return chain_by_type[file_type]
    # YouTube links get the dedicated video QA tool.
    if 'youtube.com' in question or 'youtu.be' in question:
        return ['youtube_video_qa']
    # Open-domain factual questions: search the web, then summarize.
    search_cues = ('wikipedia', 'who', 'when', 'where', 'what', 'how',
                   'find', 'search')
    lowered = question.lower()
    if any(cue in lowered for cue in search_cues):
        return ['web_search_duckduckgo', 'llama3_chat']
    return ['llama3_chat']
|
318 |
+
|
319 |
+
# --- Improved RAG: Context Retrieval & Chunking ---
def retrieve_context(question, context_files, max_chunks=3):
    """Retrieve relevant context chunks from large files for RAG.

    Simple keyword match for now (can be replaced with semantic search):
    a 2000-character chunk is relevant if it contains any word of the
    question, case-insensitively.

    Args:
        question: Question text whose whitespace-split words act as keywords.
        context_files: Iterable of file paths to scan.
        max_chunks: Global cap on the number of chunks returned, applied
            across ALL files (not per file).

    Returns:
        str: Up to ``max_chunks`` matching chunks joined by newlines
        (empty string when nothing matches).
    """
    # Hoist keyword lowering out of the per-chunk loop.
    keywords = [word.lower() for word in question.split()]
    relevant_chunks = []
    for file_path in context_files:
        if len(relevant_chunks) >= max_chunks:
            break  # enforce the cap across files, not just within one file
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
            # Fixed-size character chunks (2000 chars each).
            for start in range(0, len(text), 2000):
                chunk = text[start:start + 2000]
                lowered = chunk.lower()
                if any(word in lowered for word in keywords):
                    relevant_chunks.append(chunk)
                    if len(relevant_chunks) >= max_chunks:
                        break
        except Exception as e:
            logger.error(f"retrieve_context error: {e}")
    return '\n'.join(relevant_chunks)
|
338 |
+
|
339 |
+
# --- Modular Tool Registry & Chaining ---
class ToolRegistry:
    """Central registry for tools. Allows easy addition and chaining."""

    def __init__(self, tools):
        # Backing mapping of tool name -> callable.
        self.tools = tools

    def get(self, name):
        """Return the tool registered under *name*, or None when absent."""
        return self.tools.get(name)

    def add(self, name, func):
        """Register (or replace) *func* under *name*."""
        self.tools[name] = func

    def list(self):
        """Return the registered tool names as a list."""
        return [*self.tools]
|
350 |
+
|
351 |
# --- Refactored ModularGAIAAgent ---
class ModularGAIAAgent:
    """GAIA-compliant agent with robust reasoning, tool chaining, RAG, and output normalization."""

    def __init__(self, api_url=DEFAULT_API_URL, tool_registry=None, context_files=None):
        """Set up the agent's endpoint, tool registry, and RAG corpus.

        Args:
            api_url: Endpoint used to fetch questions / submit answers.
            tool_registry: Optional mapping of tool name -> callable;
                defaults to the module-level TOOL_REGISTRY.
            context_files: Optional list of file paths used by the RAG
                retrieval step; defaults to an empty list.
        """
        self.api_url = api_url
        # Wrap the raw tool mapping in a ToolRegistry for uniform lookup.
        self.tools = ToolRegistry(tool_registry or TOOL_REGISTRY)
        # Per-question trace of reasoning steps (reset on each question).
        self.reasoning_trace = []
        # Snapshot of files already present in the working directory.
        self.file_cache = set(os.listdir('.'))
        # Optional corpus consulted by retrieve_context() when no
        # question-attached file yields content.
        self.context_files = context_files or []
|
360 |
|
361 |
def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions"):
|
362 |
"""Fetch questions from API or local file."""
|
|
|
426 |
"""Analyze file and return context for the question."""
|
427 |
try:
|
428 |
if file_type == 'audio':
|
429 |
+
transcript = self.tools.get('asr_transcribe')(file_name)
|
430 |
self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
|
431 |
return transcript
|
432 |
elif file_type == 'image':
|
433 |
+
caption = self.tools.get('image_caption')(file_name)
|
434 |
self.reasoning_trace.append(f"Image caption: {caption}")
|
435 |
return caption
|
436 |
elif file_type == 'code':
|
437 |
+
result = self.tools.get('code_analysis')(file_name)
|
438 |
self.reasoning_trace.append(f"Code analysis result: {result}")
|
439 |
return result
|
440 |
elif file_type == 'excel':
|
|
|
469 |
self.reasoning_trace.append(f"Analyze file error: {e}")
|
470 |
return None
|
471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
def answer_question(self, question_obj):
|
|
|
473 |
self.reasoning_trace = []
|
474 |
q = question_obj["question"]
|
475 |
file_name = question_obj.get("file_name", "")
|
|
|
481 |
if local_file:
|
482 |
file_type = self.detect_file_type(local_file)
|
483 |
file_content = self.analyze_file(local_file, file_type)
|
484 |
+
# RAG: retrieve context if needed
|
485 |
+
rag_context = ''
|
486 |
+
if not file_content and self.context_files:
|
487 |
+
rag_context = retrieve_context(q, self.context_files)
|
488 |
+
if rag_context:
|
489 |
+
self.reasoning_trace.append(f"RAG context used: {rag_context[:200]}...")
|
490 |
+
# Reasoning planner: decide tool chain
|
491 |
+
tool_names = reasoning_planner(q, file_type, self.tools.list())
|
492 |
answer = None
|
493 |
+
context = file_content or rag_context
|
494 |
for tool_name in tool_names:
|
495 |
+
tool = self.tools.get(tool_name)
|
496 |
try:
|
497 |
logger.info(f"Using tool: {tool_name} | Question: {q} | Context: {str(context)[:200]}")
|
498 |
if tool_name == 'web_search_duckduckgo':
|
499 |
context = tool(q)
|
500 |
answer = llama3_chat(build_prompt(context, q))
|
|
|
|
|
501 |
elif tool_name == 'table_qa' and file_content:
|
502 |
answer = tool(q, file_content)
|
503 |
elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
|
|
|
505 |
elif tool_name == 'youtube_video_qa':
|
506 |
answer = tool(q, q)
|
507 |
else:
|
|
|
508 |
if context:
|
509 |
answer = llama3_chat(build_prompt(context, q))
|
510 |
else:
|
|
|
517 |
continue
|
518 |
self.reasoning_trace.append(f"Tools used: {tool_names}")
|
519 |
self.reasoning_trace.append(f"Final answer: {answer}")
|
520 |
+
return gaia_normalize_answer(answer), self.reasoning_trace
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
|
522 |
# --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
|
523 |
class BasicAgent:
|