Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
""" | |
๐ SmoLAgents Bridge for GAIA System | |
Integrates smolagents framework with our existing tools for 60+ point performance boost | |
""" | |
import os | |
import logging | |
from typing import Optional | |
# smolagents is optional: if any of the names below cannot be imported,
# SMOLAGENTS_AVAILABLE is False and SmoLAgentsEnhancedAgent uses the legacy
# gaia_system agent instead.
try:
    # All names are taken from the package top level. VisitWebpageTool is
    # defined in smolagents.default_tools and re-exported by the package; it
    # is not importable from smolagents.tools.
    from smolagents import (
        CodeAgent,
        DuckDuckGoSearchTool,
        InferenceClientModel,
        VisitWebpageTool,
        tool,
    )
    SMOLAGENTS_AVAILABLE = True
except ImportError:
    SMOLAGENTS_AVAILABLE = False
    # Placeholders so every name exists at module level in both cases.
    CodeAgent = None
    DuckDuckGoSearchTool = None
    InferenceClientModel = None
    VisitWebpageTool = None
    tool = None
# Import our existing system | |
from gaia_system import BasicAgent as FallbackAgent, UniversalMultimodalToolkit | |
logger = logging.getLogger(__name__)


class SmoLAgentsEnhancedAgent:
    """GAIA agent that answers questions with a smolagents ``CodeAgent``.

    If smolagents could not be imported, or building the model/agent fails,
    every query is delegated to ``gaia_system.BasicAgent`` (imported in this
    module as ``FallbackAgent``). A smolagents run that raises is also
    answered by that fallback agent.
    """

    # Inference models tried in priority order by _create_priority_model().
    _MODEL_CANDIDATES = (
        {"model_id": "Qwen/Qwen3-235B-A22B", "provider": "fireworks-ai"},
        {"model_id": "deepseek-ai/DeepSeek-R1"},
        {"model_id": "meta-llama/Llama-3.1-8B-Instruct"},
    )

    # Returned by _clean_response() when there is no usable answer.
    _NO_ANSWER = "Unable to provide answer"

    # Lower-case lead-ins stripped from answers; only the first match is removed.
    _ANSWER_PREFIXES = (
        "the answer is:",
        "answer:",
        "result:",
        "final answer:",
        "solution:",
    )

    def __init__(self, hf_token: Optional[str] = None, openai_key: Optional[str] = None):
        """Resolve credentials and build the smolagents pipeline (or the fallback).

        Args:
            hf_token: Hugging Face token; defaults to the ``HF_TOKEN`` env var.
            openai_key: OpenAI API key; defaults to the ``OPENAI_API_KEY`` env var.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        self.openai_key = openai_key or os.getenv('OPENAI_API_KEY')
        # Set up front so cleanup() and the fallback path work in every mode.
        self.toolkit = None
        self._fallback_agent = None
        self.use_smolagents = False

        if not SMOLAGENTS_AVAILABLE:
            print("โ ๏ธ SmoLAgents not available, using fallback system")
            self.agent = self._get_fallback_agent()
            return

        try:
            self.toolkit = UniversalMultimodalToolkit(self.hf_token, self.openai_key)
            # The model must exist before _create_code_agent() reads self.model.
            self.model = self._create_priority_model()
            self.agent = self._create_code_agent()
        except Exception as exc:
            # Degrade to the legacy agent rather than failing construction.
            print(f"SmoLAgents setup failed ({exc}), using fallback system")
            logger.debug("SmoLAgents setup failed", exc_info=True)
            self.agent = self._get_fallback_agent()
            return

        self.use_smolagents = True
        print("โ SmoLAgents GAIA System initialized")

    def _get_fallback_agent(self):
        """Return the legacy ``FallbackAgent``, creating and caching it on first use."""
        if getattr(self, "_fallback_agent", None) is None:
            self._fallback_agent = FallbackAgent(self.hf_token, self.openai_key)
        return self._fallback_agent

    def _create_priority_model(self):
        """Create the inference model, preferring Qwen3-235B-A22B.

        Candidates from ``_MODEL_CANDIDATES`` are tried in order. If one cannot
        be constructed, the failure is logged and the next one is tried.
        Errors from the last candidate propagate to the caller.

        NOTE(review): constructing the client probably does not contact the
        endpoint, so an unavailable model may only surface at query time.
        Verify this if the priority order matters.
        """
        *preferred, last_resort = self._MODEL_CANDIDATES
        for options in preferred:
            try:
                return InferenceClientModel(token=self.hf_token, **options)
            except Exception:
                logger.warning(
                    "Could not create model %s; trying next candidate",
                    options["model_id"],
                    exc_info=True,
                )
        return InferenceClientModel(token=self.hf_token, **last_resort)

    def _create_code_agent(self):
        """Create the ``CodeAgent`` with web tools plus the toolkit-backed tools."""
        tools = [
            DuckDuckGoSearchTool(),
            VisitWebpageTool(),
            self._create_calculator_tool(),
            self._create_image_analysis_tool(),
            self._create_file_download_tool(),
            self._create_pdf_tool(),
        ]
        agent = CodeAgent(
            tools=tools,
            model=self.model,
            max_steps=3,
            verbosity_level=0,
        )
        # Recent CodeAgent releases accept no ``system_prompt`` argument.
        # Extend the default template instead of replacing it, so that the
        # code-formatting rules the agent relies on are kept.
        agent.prompt_templates["system_prompt"] += "\n\n" + self._get_gaia_prompt()
        return agent

    def _get_gaia_prompt(self):
        """Return the GAIA answer-format instructions appended to the system prompt."""
        return "\n".join((
            "You are a GAIA benchmark expert. Use tools to solve questions step-by-step.",
            "CRITICAL: Pass ONLY the final answer to final_answer - no explanations.",
            "Format: number OR few words OR comma-separated list",
            "No units unless specified. No articles for strings.",
            "Available tools:",
            "- web_search: Search the web",
            "- visit_webpage: Visit URLs",
            "- calculator: Mathematical calculations",
            "- analyze_image: Analyze images",
            "- download_file: Download GAIA files",
            "- read_pdf: Extract PDF text",
        ))

    # The tool factories below wrap closures over self.toolkit with smolagents'
    # ``tool`` decorator. The decorator reads the type hints and the "Args:"
    # docstring section, so both must stay in place.

    def _create_calculator_tool(self):
        """Build the ``calculator`` tool, delegating to ``toolkit.calculator``."""

        def calculator(expression: str) -> str:
            """Perform mathematical calculations

            Args:
                expression: Mathematical expression to evaluate
            """
            return self.toolkit.calculator(expression)

        return tool(calculator)

    def _create_image_analysis_tool(self):
        """Build the ``analyze_image`` tool, delegating to ``toolkit.analyze_image``."""

        def analyze_image(image_path: str, question: str = "") -> str:
            """Analyze images and answer questions

            Args:
                image_path: Path to image file
                question: Question about the image
            """
            return self.toolkit.analyze_image(image_path, question)

        return tool(analyze_image)

    def _create_file_download_tool(self):
        """Build the ``download_file`` tool, delegating to ``toolkit.download_file``."""

        def download_file(url: str = "", task_id: str = "") -> str:
            """Download files from URLs or GAIA tasks

            Args:
                url: URL to download from
                task_id: GAIA task ID
            """
            return self.toolkit.download_file(url, task_id)

        return tool(download_file)

    def _create_pdf_tool(self):
        """Build the ``read_pdf`` tool, delegating to ``toolkit.read_pdf``."""

        def read_pdf(file_path: str) -> str:
            """Extract text from PDF documents

            Args:
                file_path: Path to PDF file
            """
            return self.toolkit.read_pdf(file_path)

        return tool(read_pdf)

    def query(self, question: str) -> str:
        """Answer ``question`` with smolagents, falling back to the legacy agent on error.

        Args:
            question: The GAIA question text.

        Returns:
            The cleaned answer string.
        """
        if not self.use_smolagents:
            return self.agent.query(question)

        print(f"๐ Processing with SmoLAgents: {question[:80]}...")
        try:
            response = self.agent.run(question)
        except Exception as e:
            print(f"โ ๏ธ SmoLAgents error: {e}, falling back to original system")
            logger.debug("SmoLAgents run failed", exc_info=True)
            return self._get_fallback_agent().query(question)

        cleaned = self._clean_response(response)
        print(f"โ SmoLAgents result: {cleaned}")
        return cleaned

    def _clean_response(self, response: object) -> str:
        """Normalise an agent answer for GAIA submission.

        Non-string results (``CodeAgent.run`` may return numbers) are converted
        with ``str``. Surrounding whitespace is removed, one known lead-in
        (e.g. "Answer:") is stripped case-insensitively, and trailing periods
        are removed.

        Args:
            response: Raw agent output; ``None`` or a blank result counts as no answer.

        Returns:
            The cleaned answer, or ``_NO_ANSWER`` when there is nothing to submit.
        """
        if response is None:
            return self._NO_ANSWER
        text = str(response).strip()
        if not text:
            return self._NO_ANSWER

        lowered = text.lower()
        for prefix in self._ANSWER_PREFIXES:
            if lowered.startswith(prefix):
                text = text[len(prefix):].strip()
                break
        return text.rstrip('.')

    def clean_for_api_submission(self, response: object) -> str:
        """Clean a response for GAIA API submission (alias of ``_clean_response``)."""
        return self._clean_response(response)

    def __call__(self, question: str) -> str:
        """Make the agent callable; equivalent to :meth:`query`."""
        return self.query(question)

    def cleanup(self):
        """Release toolkit resources if a toolkit exists and exposes ``cleanup``."""
        # getattr: toolkit may be absent (instance built without __init__) or
        # None (fallback mode); neither case should raise.
        toolkit = getattr(self, "toolkit", None)
        if hasattr(toolkit, "cleanup"):
            toolkit.cleanup()
def create_enhanced_agent(
    hf_token: Optional[str] = None,
    openai_key: Optional[str] = None,
) -> "SmoLAgentsEnhancedAgent":
    """Factory for :class:`SmoLAgentsEnhancedAgent`.

    Args:
        hf_token: Hugging Face token; ``None`` lets the agent read ``HF_TOKEN``.
        openai_key: OpenAI key; ``None`` lets the agent read ``OPENAI_API_KEY``.

    Returns:
        A newly constructed agent.
    """
    return SmoLAgentsEnhancedAgent(hf_token, openai_key)
if __name__ == "__main__":
    # Manual smoke test: build the agent, then for each sample question print
    # the question, run it through the agent and print the answer.
    print("๐งช Testing SmoLAgents Bridge...")
    smoke_agent = SmoLAgentsEnhancedAgent()

    smoke_questions = (
        "What is 5 + 3?",
        "What is the capital of France?",
        "How many sides does a triangle have?",
    )
    for prompt_text in smoke_questions:
        print(f"\nQ: {prompt_text}")
        answer_text = smoke_agent.query(prompt_text)
        print(f"A: {answer_text}")

    print("\nโ Bridge test completed!")