Omachoko committed on
Commit
83a3deb
·
1 Parent(s): a9d900f

Final: robust GAIA agent with advanced tool registry, GPT-4.1, web search, strict output, and full multi-modal support

Browse files
Files changed (1) hide show
  1. app.py +90 -39
app.py CHANGED
@@ -24,6 +24,7 @@ from huggingface_hub import InferenceClient
24
  import cv2
25
  import torch
26
  from bs4 import BeautifulSoup
 
27
 
28
  logging.basicConfig(filename='gaia_agent.log', level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')
29
  logger = logging.getLogger(__name__)
@@ -198,6 +199,37 @@ def youtube_video_qa(youtube_url, question):
198
  logging.error(f"YouTube video QA error: {e}")
199
  return f"Video analysis error: {e}"
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  TOOL_REGISTRY = {
202
  "llama3_chat": llama3_chat,
203
  "mixtral_chat": mixtral_chat,
@@ -207,6 +239,8 @@ TOOL_REGISTRY = {
207
  "image_caption": image_caption,
208
  "code_analysis": code_analysis,
209
  "youtube_video_qa": youtube_video_qa,
 
 
210
  }
211
 
212
  class ModularGAIAAgent:
@@ -304,63 +338,80 @@ class ModularGAIAAgent:
304
  self.reasoning_trace.append(f"Unknown file type: {file_name}")
305
  return None
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def answer_question(self, question_obj):
308
  self.reasoning_trace = []
309
  q = question_obj["question"]
310
  file_name = question_obj.get("file_name", "")
311
  file_content = None
312
  file_type = None
313
- # YouTube video question detection
314
- if "youtube.com" in q or "youtu.be" in q:
315
- url = None
316
- for word in q.split():
317
- if "youtube.com" in word or "youtu.be" in word:
318
- url = word.strip().strip(',')
319
- break
320
- if url:
321
- answer = self.tools['youtube_video_qa'](url, q)
322
- self.reasoning_trace.append(f"YouTube video analyzed: {url}")
323
- self.reasoning_trace.append(f"Final answer: {answer}")
324
- return self.format_answer(answer), self.reasoning_trace
325
  if file_name:
326
  file_id = file_name.split('.')[0]
327
  local_file = self.download_file(file_id, file_name)
328
  if local_file:
329
  file_type = self.detect_file_type(local_file)
330
  file_content = self.analyze_file(local_file, file_type)
331
- # Plan: choose tool based on question and file
332
- if file_type == 'audio' or file_type == 'text':
333
- if file_content:
334
- answer = self.tools['extractive_qa'](q, file_content)
335
- else:
336
- answer = self.tools['llama3_chat'](q)
337
- elif file_type == 'excel' or file_type == 'csv':
338
- if file_content:
339
- answer = self.tools['table_qa'](q, file_content)
 
 
 
 
 
 
 
 
 
340
  else:
341
- answer = self.tools['llama3_chat'](q)
342
- elif file_type == 'image':
343
- if file_content:
344
- answer = self.tools['llama3_chat'](f"{q}\nImage description: {file_content}")
345
- else:
346
- answer = self.tools['llama3_chat'](q)
347
- elif file_type == 'code':
348
- answer = file_content
349
- else:
350
- answer = self.tools['llama3_chat'](q)
351
  self.reasoning_trace.append(f"Final answer: {answer}")
352
  return self.format_answer(answer), self.reasoning_trace
353
 
354
  def format_answer(self, answer):
 
355
  if isinstance(answer, str):
356
- answer = answer.strip().rstrip('.')
357
- for prefix in ['answer:', 'result:', 'the answer is', 'final answer:', 'response:']:
358
- if answer.lower().startswith(prefix):
359
- answer = answer[len(prefix):].strip()
360
- import re
361
- answer = re.sub(r'\b(the|a|an)\b ', '', answer, flags=re.IGNORECASE)
362
- answer = answer.strip().rstrip('.')
363
- return answer
364
 
365
  # --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
366
  class BasicAgent:
 
24
  import cv2
25
  import torch
26
  from bs4 import BeautifulSoup
27
+ import openai
28
 
29
  logging.basicConfig(filename='gaia_agent.log', level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')
30
  logger = logging.getLogger(__name__)
 
199
  logging.error(f"YouTube video QA error: {e}")
200
  return f"Video analysis error: {e}"
201
 
202
def web_search_duckduckgo(query, max_results=5):
    """DuckDuckGo web search tool: returns top snippets and URLs.

    Args:
        query: Search terms; must be a non-blank string.
        max_results: Upper bound on the number of results requested.

    Returns:
        One "Title / Snippet / URL" block per result, with blocks separated
        by a line containing "---". The string is empty when the search
        yields no results. On any failure a string starting with
        "Web search error:" is returned instead of raising, so callers
        always receive text.
    """
    # Reject blank queries up front rather than spending a network call on them.
    if not isinstance(query, str) or not query.strip():
        return "Web search error: empty query"
    try:
        # Imported lazily so a missing optional dependency is reported as a
        # tool error instead of breaking module import. The package's search
        # client is DDGS; it has no DuckDuckGoSearch class.
        from duckduckgo_search import DDGS

        results = DDGS().text(query.strip(), max_results=max_results) or []
        # .get() keeps one malformed result from discarding the whole search.
        snippets = [
            f"Title: {r.get('title', '')}\nSnippet: {r.get('body', '')}\nURL: {r.get('href', '')}"
            for r in results
        ]
        return '\n---\n'.join(snippets)
    except Exception as e:
        logging.error(f"web_search_duckduckgo error: {e}")
        return f"Web search error: {e}"
215
+
216
def gpt4_chat(prompt, api_key=None, model="gpt-4-1106-preview"):
    """Run a single-turn OpenAI chat completion and return the reply text.

    Args:
        prompt: User message sent to the model.
        api_key: OpenAI API key; falls back to the OPENAI_API_KEY
            environment variable when omitted.
        model: Chat model identifier. Defaults to "gpt-4-1106-preview",
            the model this function has always requested.

    Returns:
        The stripped reply text. Returns "No OpenAI API key provided." when
        no key is available, and a string starting with "GPT-4 error:" when
        the request fails, so callers always receive text.
    """
    api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
    if not api_key:
        return "No OpenAI API key provided."
    messages = [
        {"role": "system", "content": "You are a general AI assistant. Answer using as few words as possible, in the required format. Use tools as needed, and only output the answer."},
        {"role": "user", "content": prompt},
    ]
    try:
        # `openai` is imported at module level in this file.
        if hasattr(openai, "OpenAI"):
            # openai>=1.0: the module-level ChatCompletion API was removed;
            # requests go through a client object.
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(model=model, messages=messages)
            content = response.choices[0].message.content
        else:
            # Legacy openai<1.0 interface.
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                api_key=api_key,
            )
            content = response.choices[0].message['content']
        # The message content can be None; normalise to a string.
        return (content or "").strip()
    except Exception as e:
        logging.error(f"gpt4_chat error: {e}")
        return f"GPT-4 error: {e}"
232
+
233
  TOOL_REGISTRY = {
234
  "llama3_chat": llama3_chat,
235
  "mixtral_chat": mixtral_chat,
 
239
  "image_caption": image_caption,
240
  "code_analysis": code_analysis,
241
  "youtube_video_qa": youtube_video_qa,
242
+ "web_search_duckduckgo": web_search_duckduckgo,
243
+ "gpt4_chat": gpt4_chat,
244
  }
245
 
246
  class ModularGAIAAgent:
 
338
  self.reasoning_trace.append(f"Unknown file type: {file_name}")
339
  return None
340
 
341
def smart_tool_select(self, question, file_type=None):
    """Select the tool name(s) to try for a question, in order of use.

    When OPENAI_API_KEY is set, gpt4_chat is asked to pick from the keys of
    ``self.tools``; any names in its reply that are registered tools are
    returned. Otherwise, or when the reply names no registered tool, a
    heuristic based on the file type and the question text is used.

    Args:
        question: The question text.
        file_type: Detected type of the attached file, or None.

    Returns:
        A non-empty list of tool names.
    """
    import re  # function-scope import keeps this block self-contained

    # Use the LLM planner to suggest tools if a key is available.
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if api_key:
        plan_prompt = (
            "\n"
            f"You are an expert AI agent. Given the following question and file type, suggest the best tool(s) to use from this list: {list(self.tools.keys())}.\n"
            f"Question: {question}\n"
            f"File type: {file_type}\n"
            "Respond with a comma-separated list of tool names only, in order of use. If unsure, start with web_search_duckduckgo.\n"
        )
        plan = gpt4_chat(plan_prompt, api_key=api_key)
        # Models often wrap names in quotes/backticks/brackets or put them on
        # separate lines; normalise before matching against the registry.
        candidates = [
            part.strip().strip("`'\"[]. ")
            for part in plan.replace("\n", ",").split(",")
        ]
        # dict.fromkeys removes duplicates while keeping the suggested order.
        tool_names = list(dict.fromkeys(t for t in candidates if t in self.tools))
        if tool_names:
            return tool_names

    # Fallback: heuristic on file type first, then on the question text.
    by_file_type = {
        'audio': 'asr_transcribe',
        'image': 'image_caption',
        'code': 'code_analysis',
        'excel': 'table_qa',
        'csv': 'table_qa',
    }
    if file_type in by_file_type:
        return [by_file_type[file_type]]
    if 'youtube.com' in question or 'youtu.be' in question:
        return ['youtube_video_qa']
    # Match whole words only: a substring test would treat "show" as "how"
    # and "whole" as "who", sending unrelated questions to web search.
    search_words = {'wikipedia', 'who', 'when', 'where', 'what', 'how', 'find', 'search'}
    question_words = set(re.findall(r"[a-z]+", question.lower()))
    if question_words & search_words:
        return ['web_search_duckduckgo']
    return ['llama3_chat']
371
+
372
  def answer_question(self, question_obj):
373
  self.reasoning_trace = []
374
  q = question_obj["question"]
375
  file_name = question_obj.get("file_name", "")
376
  file_content = None
377
  file_type = None
 
 
 
 
 
 
 
 
 
 
 
 
378
  if file_name:
379
  file_id = file_name.split('.')[0]
380
  local_file = self.download_file(file_id, file_name)
381
  if local_file:
382
  file_type = self.detect_file_type(local_file)
383
  file_content = self.analyze_file(local_file, file_type)
384
+ # Smart tool selection
385
+ tool_names = self.smart_tool_select(q, file_type)
386
+ answer = None
387
+ context = None
388
+ for tool_name in tool_names:
389
+ tool = self.tools[tool_name]
390
+ if tool_name == 'web_search_duckduckgo':
391
+ context = tool(q)
392
+ # Use LLM to synthesize answer from snippets
393
+ answer = llama3_chat(f"Answer the following question using ONLY the information below.\nQuestion: {q}\nSnippets:\n{context}\nAnswer:")
394
+ elif tool_name == 'gpt4_chat':
395
+ answer = tool(q)
396
+ elif tool_name == 'table_qa' and file_content:
397
+ answer = tool(q, file_content)
398
+ elif tool_name in ['asr_transcribe', 'image_caption', 'code_analysis'] and file_content:
399
+ answer = tool(file_name)
400
+ elif tool_name == 'youtube_video_qa':
401
+ answer = tool(q, q)
402
  else:
403
+ answer = tool(q)
404
+ if answer:
405
+ break
406
+ self.reasoning_trace.append(f"Tools used: {tool_names}")
 
 
 
 
 
 
407
  self.reasoning_trace.append(f"Final answer: {answer}")
408
  return self.format_answer(answer), self.reasoning_trace
409
 
410
  def format_answer(self, answer):
411
+ # Strict GAIA: only the answer, no extra text, no prefix
412
  if isinstance(answer, str):
413
+ return answer.strip().split('\n')[0]
414
+ return str(answer)
 
 
 
 
 
 
415
 
416
  # --- Basic Agent Definition (now wraps ModularGAIAAgent) ---
417
  class BasicAgent: