#!/usr/bin/env python3
"""
🚀 SmoLAgents Bridge for GAIA System
Integrates smolagents framework with our existing tools for 60+ point performance boost
"""
import os
import logging
from typing import Optional

# Try to import smolagents. Every class is taken from the package root:
# VisitWebpageTool is not exported by `smolagents.tools`, so importing it from
# there raised ImportError and silently disabled the whole smolagents path.
try:
    from smolagents import (
        CodeAgent,
        DuckDuckGoSearchTool,
        InferenceClientModel,
        VisitWebpageTool,
        tool,
    )
    SMOLAGENTS_AVAILABLE = True
except ImportError:
    SMOLAGENTS_AVAILABLE = False
    # Keep every optional name defined so later references fail predictably
    # (None) instead of with NameError.
    CodeAgent = None
    DuckDuckGoSearchTool = None
    InferenceClientModel = None
    VisitWebpageTool = None
    tool = None

# Import our existing system and enhanced tools
from gaia_system import BasicAgent as FallbackAgent, UniversalMultimodalToolkit

try:
    from enhanced_gaia_tools import EnhancedGAIATools
    ENHANCED_TOOLS_AVAILABLE = True
except ImportError:
    ENHANCED_TOOLS_AVAILABLE = False
    EnhancedGAIATools = None

logger = logging.getLogger(__name__)
class SmoLAgentsEnhancedAgent:
    """🚀 Enhanced GAIA agent powered by SmoLAgents framework

    When smolagents is importable, questions are answered by a ``CodeAgent``
    equipped with web search, page visiting and the toolkit-backed tools
    created below. Otherwise (or when the ``CodeAgent`` raises while answering)
    the question is delegated to ``FallbackAgent``.
    """

    def __init__(self, hf_token: Optional[str] = None, openai_key: Optional[str] = None):
        """Set up either the SmoLAgents pipeline or the fallback agent.

        Args:
            hf_token: Hugging Face token; defaults to the ``HF_TOKEN``
                environment variable.
            openai_key: OpenAI key; defaults to the ``OPENAI_API_KEY``
                environment variable.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        self.openai_key = openai_key or os.getenv('OPENAI_API_KEY')

        # Defined up front so cleanup() and the tool closures are safe in
        # every mode, including the early-return fallback path below.
        self.toolkit = None
        self.enhanced_tools = None
        self.model = None
        self._fallback_agent = None

        if not SMOLAGENTS_AVAILABLE:
            print("⚠️ SmoLAgents not available, using fallback system")
            # Use the resolved credentials (argument or environment), the same
            # values the error-path fallback in query() uses.
            self.agent = FallbackAgent(self.hf_token, self.openai_key)
            self._fallback_agent = self.agent
            self.use_smolagents = False
            return

        self.use_smolagents = True
        self.toolkit = UniversalMultimodalToolkit(self.hf_token, self.openai_key)

        # Initialize enhanced tools if available
        if ENHANCED_TOOLS_AVAILABLE:
            self.enhanced_tools = EnhancedGAIATools(self.hf_token, self.openai_key)
            print("✅ Enhanced GAIA tools loaded")
        else:
            self.enhanced_tools = None
            print("⚠️ Enhanced GAIA tools not available")

        # Create model with our priority system
        self.model = self._create_priority_model()

        # Create CodeAgent with our tools
        self.agent = self._create_code_agent()
        print("✅ SmoLAgents GAIA System initialized with enhanced tools")

    def _create_priority_model(self):
        """Create model with Qwen3-235B-A22B priority

        Candidates are tried in order; the first one that constructs without
        raising is returned. If the last candidate also raises, that exception
        propagates to the caller.
        """
        candidates = (
            # Priority 1: Qwen3-235B-A22B (Best for GAIA)
            {"model_id": "Qwen/Qwen3-235B-A22B", "provider": "fireworks-ai"},
            # Priority 2: DeepSeek-R1
            {"model_id": "deepseek-ai/DeepSeek-R1"},
            # Fallback
            {"model_id": "meta-llama/Llama-3.1-8B-Instruct"},
        )
        last_index = len(candidates) - 1
        for index, model_kwargs in enumerate(candidates):
            try:
                # `model_id` (not `model`) selects the checkpoint and `token`
                # carries the Hugging Face credential.
                return InferenceClientModel(token=self.hf_token, **model_kwargs)
            except Exception as exc:
                if index == last_index:
                    raise
                print(f"⚠️ Model {model_kwargs['model_id']} unavailable ({exc}), trying next option")

    def _create_code_agent(self):
        """Create CodeAgent with essential tools + enhanced tools"""
        tools = [
            DuckDuckGoSearchTool(),
            VisitWebpageTool(),
            self._create_calculator_tool(),
            self._create_image_analysis_tool(),
            self._create_file_download_tool(),
            self._create_pdf_tool(),
        ]

        # Add enhanced tools if available
        if self.enhanced_tools:
            tools.extend([
                self._create_enhanced_docx_tool(),
                self._create_enhanced_excel_tool(),
                self._create_enhanced_csv_tool(),
                self._create_enhanced_browse_tool(),
                self._create_enhanced_gaia_download_tool(),
            ])
            print(f"✅ Added {len(tools)} tools including enhanced capabilities")

        agent = CodeAgent(
            tools=tools,
            model=self.model,
            max_steps=3,
            verbosity_level=0,
        )
        # CodeAgent no longer accepts a `system_prompt` argument. Append the
        # GAIA rules to the built-in template instead of replacing it, so the
        # framework's own tool descriptions and code-format rules are kept.
        agent.prompt_templates["system_prompt"] = (
            agent.prompt_templates["system_prompt"] + "\n\n" + self._get_gaia_prompt()
        )
        return agent

    def _get_gaia_prompt(self):
        """GAIA-optimized system prompt with enhanced tools

        The enhanced tool lines are included only when ``self.enhanced_tools``
        is truthy.
        """
        enhanced_tools_info = ""
        if self.enhanced_tools:
            enhanced_tools_info = """
- read_docx: Read Microsoft Word documents
- read_excel: Read Excel spreadsheets
- read_csv: Read CSV files with advanced parsing
- browse_with_js: Enhanced web browsing with JavaScript
- download_gaia_file: Enhanced GAIA file downloads with auto-processing"""
        return f"""You are a GAIA benchmark expert. Use tools to solve questions step-by-step.
CRITICAL: Provide ONLY the final answer - no explanations.
Format: number OR few words OR comma-separated list
No units unless specified. No articles for strings.
Available tools:
- DuckDuckGoSearchTool: Search the web
- VisitWebpageTool: Visit URLs
- calculator: Mathematical calculations
- analyze_image: Analyze images
- download_file: Download GAIA files
- read_pdf: Extract PDF text{enhanced_tools_info}
Enhanced GAIA compliance: Use the most appropriate tool for each task."""

    # ------------------------------------------------------------------
    # Tool factories. Each returns a smolagents Tool built with the `tool`
    # decorator, because CodeAgent only accepts Tool instances (a plain
    # function is rejected). They are only called on the smolagents path.
    # ------------------------------------------------------------------

    def _create_calculator_tool(self):
        """🧮 Mathematical calculations"""
        @tool
        def calculator(expression: str) -> str:
            """Perform mathematical calculations

            Args:
                expression: Mathematical expression to evaluate
            """
            return self.toolkit.calculator(expression)

        return calculator

    def _create_image_analysis_tool(self):
        """🖼️ Image analysis"""
        @tool
        def analyze_image(image_path: str, question: str = "") -> str:
            """Analyze images and answer questions

            Args:
                image_path: Path to image file
                question: Question about the image
            """
            return self.toolkit.analyze_image(image_path, question)

        return analyze_image

    def _create_file_download_tool(self):
        """📥 File downloads"""
        @tool
        def download_file(url: str = "", task_id: str = "") -> str:
            """Download files from URLs or GAIA tasks

            Args:
                url: URL to download from
                task_id: GAIA task ID
            """
            return self.toolkit.download_file(url, task_id)

        return download_file

    def _create_pdf_tool(self):
        """📄 PDF reading"""
        @tool
        def read_pdf(file_path: str) -> str:
            """Extract text from PDF documents

            Args:
                file_path: Path to PDF file
            """
            return self.toolkit.read_pdf(file_path)

        return read_pdf

    def _create_enhanced_docx_tool(self):
        """📄 Enhanced Word document reading"""
        @tool
        def read_docx(file_path: str) -> str:
            """Read Microsoft Word documents with enhanced processing

            Args:
                file_path: Path to DOCX file
            """
            if self.enhanced_tools:
                return self.enhanced_tools.read_docx(file_path)
            return "❌ Enhanced DOCX reading not available"

        return read_docx

    def _create_enhanced_excel_tool(self):
        """📊 Enhanced Excel reading"""
        @tool
        def read_excel(file_path: str, sheet_name: Optional[str] = None) -> str:
            """Read Excel spreadsheets with advanced parsing

            Args:
                file_path: Path to Excel file
                sheet_name: Optional sheet name to read
            """
            if self.enhanced_tools:
                return self.enhanced_tools.read_excel(file_path, sheet_name)
            return "❌ Enhanced Excel reading not available"

        return read_excel

    def _create_enhanced_csv_tool(self):
        """📋 Enhanced CSV reading"""
        @tool
        def read_csv(file_path: str) -> str:
            """Read CSV files with enhanced processing

            Args:
                file_path: Path to CSV file
            """
            if self.enhanced_tools:
                return self.enhanced_tools.read_csv(file_path)
            return "❌ Enhanced CSV reading not available"

        return read_csv

    def _create_enhanced_browse_tool(self):
        """🌐 Enhanced web browsing"""
        @tool
        def browse_with_js(url: str) -> str:
            """Enhanced web browsing with JavaScript support

            Args:
                url: URL to browse
            """
            if self.enhanced_tools:
                return self.enhanced_tools.browse_with_js(url)
            return "❌ Enhanced browsing not available"

        return browse_with_js

    def _create_enhanced_gaia_download_tool(self):
        """📥 Enhanced GAIA file downloads"""
        @tool
        def download_gaia_file(task_id: str, file_name: Optional[str] = None) -> str:
            """Enhanced GAIA file download with auto-processing

            Args:
                task_id: GAIA task identifier
                file_name: Optional filename override
            """
            if self.enhanced_tools:
                return self.enhanced_tools.download_gaia_file(task_id, file_name)
            return "❌ Enhanced GAIA downloads not available"

        return download_gaia_file

    def _get_fallback_agent(self):
        """Return the FallbackAgent, creating it on first use and reusing it after."""
        fallback = getattr(self, "_fallback_agent", None)
        if fallback is None:
            fallback = FallbackAgent(self.hf_token, self.openai_key)
            self._fallback_agent = fallback
        return fallback

    def query(self, question: str) -> str:
        """Process question with SmoLAgents or fallback

        Args:
            question: The question text to answer.

        Returns:
            The cleaned SmoLAgents answer, or the fallback agent's answer when
            SmoLAgents is disabled or raises while processing.
        """
        if not self.use_smolagents:
            return self.agent.query(question)

        try:
            print(f"🚀 Processing with SmoLAgents: {question[:80]}...")
            response = self.agent.run(question)
            cleaned = self._clean_response(response)
            print(f"✅ SmoLAgents result: {cleaned}")
            return cleaned
        except Exception as e:
            print(f"⚠️ SmoLAgents error: {e}, falling back to original system")
            # Fallback to original system
            return self._get_fallback_agent().query(question)

    def _clean_response(self, response: object) -> str:
        """Clean response for GAIA compliance

        Args:
            response: Raw agent output. It may be a non-string value (for
                example a number produced by executed code), so it is
                converted with ``str()`` before cleaning.

        Returns:
            The text with surrounding whitespace, one leading "answer:"-style
            prefix and trailing periods removed, or
            ``"Unable to provide answer"`` when there is no content.
        """
        # Only None / blank text mean "no answer"; a numeric 0 is a real answer.
        if response is None:
            return "Unable to provide answer"
        text = str(response).strip()
        if not text:
            return "Unable to provide answer"

        # Remove common prefixes (at most one, first match wins)
        prefixes = ["the answer is:", "answer:", "result:", "final answer:", "solution:"]
        lowered = text.lower()
        for prefix in prefixes:
            if lowered.startswith(prefix):
                text = text[len(prefix):].strip()
                break
        return text.rstrip('.')

    def clean_for_api_submission(self, response: str) -> str:
        """Clean response for GAIA API submission (compatibility method)"""
        return self._clean_response(response)

    def __call__(self, question: str) -> str:
        """Make agent callable"""
        return self.query(question)

    def cleanup(self):
        """Clean up resources

        Safe to call in fallback mode, where no toolkit was created.
        """
        toolkit = getattr(self, "toolkit", None)
        if toolkit is not None and hasattr(toolkit, 'cleanup'):
            toolkit.cleanup()
def create_enhanced_agent(hf_token: str = None, openai_key: str = None) -> SmoLAgentsEnhancedAgent:
    """Build and return a ``SmoLAgentsEnhancedAgent``.

    Args:
        hf_token: Hugging Face token forwarded to the agent constructor.
        openai_key: OpenAI API key forwarded to the agent constructor.

    Returns:
        The newly constructed agent.
    """
    enhanced_agent = SmoLAgentsEnhancedAgent(hf_token=hf_token, openai_key=openai_key)
    return enhanced_agent
if __name__ == "__main__":
    # Manual smoke test: build the agent with credentials from the environment
    # and print its answer to a few simple questions.
    print("🧪 Testing SmoLAgents Bridge...")
    bridge = SmoLAgentsEnhancedAgent()
    for sample_question in (
        "What is 5 + 3?",
        "What is the capital of France?",
        "How many sides does a triangle have?",
    ):
        print(f"\nQ: {sample_question}")
        answer = bridge.query(sample_question)
        print(f"A: {answer}")
    print("\n✅ Bridge test completed!")