"""
Warm Sandbox Pool for Modal - Async Queue-Based Implementation
This module provides a pre-warmed pool of Modal sandboxes to reduce cold-start latency.
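
Example usage (a minimal sketch; it assumes the caller already has a Modal ``app``
and ``image``, e.g. from ``modal.App.lookup(...)`` and ``modal.Image.debian_slim()``):

    pool = WarmSandboxPool(app=app, image=image, pool_size=3)
    await pool.start()
    try:
        async with pool.get_sandbox(timeout=10.0) as sb:
            proc = sb.exec("python", "-c", "print('hello')")
            print(proc.stdout.read())
    finally:
        await pool.stop()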
"""
import asyncio
import time
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
import modal
from mcp_hub.logging_config import logger
from mcp_hub.exceptions import CodeExecutionError
class SandboxHealth(Enum):
"""Sandbox health status."""
HEALTHY = "healthy"
UNHEALTHY = "unhealthy"
UNKNOWN = "unknown"
@dataclass
class PooledSandbox:
"""Container for a pooled sandbox with metadata."""
sandbox: modal.Sandbox
created_at: float
last_used: float
health: SandboxHealth = SandboxHealth.UNKNOWN
use_count: int = 0
class WarmSandboxPool:
"""Async queue-based warm sandbox pool with health checking."""
def __init__(
self,
app: modal.App,
image: modal.Image,
pool_size: int = 3,
max_age_seconds: int = 300, # 5 minutes
max_uses_per_sandbox: int = 10,
health_check_interval: int = 60, # 1 minute
):
self.app = app
self.image = image
self.pool_size = pool_size
self.max_age_seconds = max_age_seconds
self.max_uses_per_sandbox = max_uses_per_sandbox
self.health_check_interval = health_check_interval
# Queue to hold available sandboxes
self._sandbox_queue: asyncio.Queue[PooledSandbox] = asyncio.Queue(maxsize=pool_size)
# Background tasks
self._warmup_task: Optional[asyncio.Task] = None
self._health_check_task: Optional[asyncio.Task] = None
self._cleanup_task: Optional[asyncio.Task] = None
# Pool statistics
self._stats = {
"created": 0,
"reused": 0,
"recycled": 0,
"health_checks": 0,
"failures": 0
}
# Health tracking for better error recovery
self._consecutive_failures = 0
self._last_successful_creation = time.time()
self._pool_reset_threshold = 5 # Reset pool after 5 consecutive failures
self._running = False
async def start(self):
"""Start the pool and background tasks."""
if self._running:
return
self._running = True
logger.info(f"Starting warm sandbox pool with {self.pool_size} sandboxes")
# Start background tasks
self._warmup_task = asyncio.create_task(self._warmup_pool())
self._health_check_task = asyncio.create_task(self._health_check_loop())
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
# Wait for initial warmup
await asyncio.sleep(1) # Give warmup a moment to start
async def stop(self):
"""Stop the pool and cleanup resources."""
if not self._running:
return
self._running = False
logger.info("Stopping warm sandbox pool")
# Cancel background tasks
for task in [self._warmup_task, self._health_check_task, self._cleanup_task]:
if task and not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Cleanup remaining sandboxes
while not self._sandbox_queue.empty():
try:
pooled_sb = self._sandbox_queue.get_nowait()
await self._terminate_sandbox(pooled_sb.sandbox)
except asyncio.QueueEmpty:
break
@asynccontextmanager
async def get_sandbox(self, timeout: float = 5.0):
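        """Async context manager that yields a live Modal sandbox.

        Tries to reuse a warm sandbox from the pool and falls back to creating a
        fresh one when the pool is empty or only holds dead sandboxes; the
        ``finally`` block then decides whether to recycle the sandbox back into
        the pool or terminate it.
        """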
pooled_sb = None
created_new = False
try:
# Check if we need to reset the pool due to consecutive failures
if self._consecutive_failures >= self._pool_reset_threshold:
logger.warning(f"Pool has {self._consecutive_failures} consecutive failures, attempting reset")
await self._emergency_pool_reset()
# Try to get a warm sandbox from the pool, retry if not alive
max_retries = 3 # Increased retries for better reliability
for attempt in range(max_retries):
try:
# Try to get from pool first
pooled_sb = await asyncio.wait_for(self._sandbox_queue.get(), timeout=timeout)
# Check if the sandbox is alive
alive = await self._is_sandbox_alive(pooled_sb.sandbox)
if not alive:
logger.info(f"Got dead sandbox from pool on attempt {attempt + 1}, terminating and trying next.")
await self._terminate_sandbox(pooled_sb.sandbox)
pooled_sb = None
continue # Try again
# Sandbox is alive, use it
pooled_sb.last_used = time.time()
pooled_sb.use_count += 1
self._stats["reused"] += 1
self._consecutive_failures = 0 # Reset failure counter on success
break
except asyncio.TimeoutError:
# Pool empty or taking too long, create a new one
logger.info(f"Pool timeout on attempt {attempt + 1}, creating new sandbox")
try:
sandbox = await self._create_sandbox()
pooled_sb = PooledSandbox(
sandbox=sandbox,
created_at=time.time(),
last_used=time.time(),
use_count=1
)
created_new = True
self._stats["created"] += 1
self._consecutive_failures = 0 # Reset failure counter on success
self._last_successful_creation = time.time()
break
except Exception as create_error:
logger.error(f"Failed to create sandbox on attempt {attempt + 1}: {create_error}")
self._consecutive_failures += 1
if attempt == max_retries - 1: # Last attempt
raise CodeExecutionError(f"Failed to create sandbox after {max_retries} attempts: {create_error}")
await asyncio.sleep(2 ** attempt) # Exponential backoff
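            # for/else: the else branch runs only when the loop finishes without a
            # break, i.e. no live or newly created sandbox could be obtained.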
else:
self._consecutive_failures += 1
raise CodeExecutionError("Could not obtain a live sandbox from the pool after all retry attempts.")
logger.info(f"Yielding sandbox of type from sandbox_pool: {type(pooled_sb.sandbox)}")
yield pooled_sb.sandbox
        except CodeExecutionError as e:
            # Already counted toward consecutive failures where it was raised
            logger.error(f"Error getting sandbox: {e}")
            self._stats["failures"] += 1
            raise
        except Exception as e:
            logger.error(f"Error getting sandbox: {e}")
            self._stats["failures"] += 1
            self._consecutive_failures += 1
            raise CodeExecutionError(f"Failed to get sandbox: {e}") from e
finally:
if pooled_sb:
should_recycle = (
not created_new and
self._should_recycle_sandbox(pooled_sb) and
self._running
)
if should_recycle:
# Double-check sandbox is alive and functional before returning to pool
if await self._is_sandbox_alive(pooled_sb.sandbox):
# Additional check: try a quick execution to ensure sandbox is fully functional
try:
await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(
None,
lambda: pooled_sb.sandbox.exec("python", "-c", "import sys; print('ready')", timeout=2)
),
timeout=3.0
)
# Sandbox is healthy and functional - return to pool
try:
self._sandbox_queue.put_nowait(pooled_sb)
logger.debug("Returned healthy sandbox to pool")
except asyncio.QueueFull:
# Pool is full - terminate excess sandbox
await self._terminate_sandbox(pooled_sb.sandbox)
logger.debug("Pool full, terminated excess sandbox")
except Exception as e:
# Sandbox failed functional test - terminate it
logger.debug(f"Sandbox failed functional test, terminating: {e}")
await self._terminate_sandbox(pooled_sb.sandbox)
else:
# Sandbox is dead - terminate it
logger.debug("Sandbox is dead, terminating instead of recycling")
await self._terminate_sandbox(pooled_sb.sandbox)
else:
# Should not recycle - terminate sandbox
await self._terminate_sandbox(pooled_sb.sandbox)
if not created_new:
self._stats["recycled"] += 1
logger.debug("Terminated sandbox (exceeded recycle criteria)")
async def _create_sandbox(self) -> modal.Sandbox:
"""Create a new Modal sandbox with timeout protection."""
try:
# Add timeout protection for sandbox creation
sandbox_creation = asyncio.get_event_loop().run_in_executor(
None,
lambda: modal.Sandbox.create(
app=self.app,
image=self.image,
cpu=2.0,
memory=1024,
timeout=35
)
)
# Wait for sandbox creation with timeout
sandbox = await asyncio.wait_for(sandbox_creation, timeout=120) # 2 minute timeout
logger.debug(f"Created new sandbox of type: {type(sandbox)}")
return sandbox
        except asyncio.TimeoutError:
            logger.error("Sandbox creation timed out after 2 minutes")
            raise CodeExecutionError("Sandbox creation timed out - Modal may be experiencing issues")
except Exception as e:
logger.error(f"Failed to create sandbox: {e}")
raise
async def _terminate_sandbox(self, sandbox: modal.Sandbox):
"""Safely terminate a sandbox with better error handling."""
try:
# Check if sandbox is still responsive before termination
if hasattr(sandbox, '_terminated') and sandbox._terminated:
logger.debug("Sandbox already terminated")
return
# Use asyncio timeout for termination
await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(None, sandbox.terminate),
timeout=10.0 # 10 second timeout for termination
)
logger.debug("Terminated sandbox successfully")
except asyncio.TimeoutError:
logger.warning("Sandbox termination timed out - may be unresponsive")
except Exception as e:
# Log the error but don't fail - sandbox may already be dead
logger.warning(f"Failed to terminate sandbox (may already be dead): {e}")
            # Mark the sandbox as terminated so later liveness checks can skip it
            sandbox._terminated = True
def _should_recycle_sandbox(self, pooled_sb: PooledSandbox) -> bool:
"""Determine if a sandbox should be recycled back to the pool."""
now = time.time()
# Check age
if now - pooled_sb.created_at > self.max_age_seconds:
logger.debug("Sandbox too old, not recycling")
return False
# Check usage count
if pooled_sb.use_count >= self.max_uses_per_sandbox:
logger.debug("Sandbox used too many times, not recycling")
return False
# Check health (if we've checked it)
if pooled_sb.health == SandboxHealth.UNHEALTHY:
logger.debug("Sandbox unhealthy, not recycling")
return False
return True
async def _warmup_pool(self):
"""Background task to maintain warm sandboxes in the pool with aggressive replenishment."""
while self._running:
try:
current_size = self._sandbox_queue.qsize()
# More aggressive warmup - start warming when below 90% capacity
warmup_threshold = max(1, int(self.pool_size * 0.9))
if current_size < warmup_threshold:
needed = self.pool_size - current_size
logger.info(f"Pool size ({current_size}) below threshold ({warmup_threshold}). Warming {needed} sandboxes...")
# Create new sandboxes to fill the pool - but limit concurrent creation
max_concurrent = min(needed, 2) # Don't overwhelm Modal
tasks = []
for _ in range(max_concurrent):
task = asyncio.create_task(self._create_and_queue_sandbox())
tasks.append(task)
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
# Log any failures
successful = 0
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(f"Failed to create sandbox {i+1}/{max_concurrent}: {result}")
else:
successful += 1
if successful > 0:
logger.info(f"Successfully warmed {successful}/{max_concurrent} sandboxes")
                # Adaptive sleep interval based on current pool health
                current_size = self._sandbox_queue.qsize()  # Re-read after any warmup above
                if current_size == 0:
# Critical: no sandboxes available
sleep_interval = 1
elif current_size < warmup_threshold:
# Low: need more sandboxes
sleep_interval = 2
else:
# Healthy: normal monitoring
sleep_interval = 5
await asyncio.sleep(sleep_interval)
except Exception as e:
logger.error(f"Error in warmup loop: {e}")
await asyncio.sleep(10) # Wait longer on error
async def _create_and_queue_sandbox(self):
"""Create a sandbox and add it to the queue."""
start_time = time.time()
try:
# Create the sandbox
sandbox = await self._create_sandbox()
creation_time = time.time() - start_time
logger.info(f"Sandbox creation took {creation_time:.2f}s")
# Proactively warm up the sandbox with core imports
warmup_start = time.time()
await self._warmup_sandbox_imports(sandbox)
warmup_time = time.time() - warmup_start
logger.info(f"Sandbox warmup with imports took {warmup_time:.2f}s")
pooled_sb = PooledSandbox(
sandbox=sandbox,
created_at=time.time(),
last_used=time.time()
)
try:
self._sandbox_queue.put_nowait(pooled_sb)
total_time = time.time() - start_time
logger.info(f"Added warm sandbox to pool (total time: {total_time:.2f}s)")
except asyncio.QueueFull:
# Pool is full, terminate this sandbox
await self._terminate_sandbox(sandbox)
except Exception as e:
total_time = time.time() - start_time
logger.error(f"Failed to create and queue sandbox after {total_time:.2f}s: {e}")
async def _warmup_sandbox_imports(self, sandbox: modal.Sandbox):
"""Warm up sandbox by importing core packages."""
try:
from mcp_hub.package_utils import get_warmup_import_commands
# Get warmup commands
import_commands = get_warmup_import_commands()
warmup_script = "; ".join(import_commands)
# Execute the warmup script
logger.debug("Running sandbox warmup imports...")
proc = await asyncio.get_event_loop().run_in_executor(
None,
lambda: sandbox.exec("python", "-c", warmup_script, timeout=30)
)
# Check if warmup was successful
if hasattr(proc, 'stdout') and hasattr(proc.stdout, 'read'):
output = proc.stdout.read()
if "Core packages warmed up successfully" in output:
logger.debug("Sandbox warmup imports completed successfully")
else:
logger.warning(f"Sandbox warmup completed but output unexpected: {output}")
else:
logger.debug("Sandbox warmup imports completed")
except Exception as e:
logger.warning(f"Failed to warm up sandbox imports (sandbox still usable): {e}")
async def _health_check_loop(self):
"""Background task to check sandbox health and perform proactive cleanup."""
while self._running:
try:
# Perform regular health checks every interval
await asyncio.sleep(self.health_check_interval)
# First do a quick proactive cleanup
cleaned = await self._proactive_cleanup()
# Then do the full health check
await self._perform_health_checks()
# If we cleaned up sandboxes, trigger warmup
if cleaned > 0:
logger.info(f"Health check cleaned {cleaned} sandboxes, pool may need warming")
except Exception as e:
logger.error(f"Error in health check loop: {e}")
await asyncio.sleep(10) # Wait longer on error
async def _perform_health_checks(self):
"""Perform health checks on sandboxes in the pool."""
        # Run a lightweight command in each pooled sandbox and drop any that are unresponsive.
temp_sandboxes = []
# Drain the queue to check each sandbox
while not self._sandbox_queue.empty():
try:
pooled_sb = self._sandbox_queue.get_nowait()
is_healthy = await self._check_sandbox_health(pooled_sb.sandbox)
pooled_sb.health = SandboxHealth.HEALTHY if is_healthy else SandboxHealth.UNHEALTHY
if is_healthy:
temp_sandboxes.append(pooled_sb)
else:
# TERMINATE unhealthy sandbox
await self._terminate_sandbox(pooled_sb.sandbox)
self._stats["recycled"] += 1
except asyncio.QueueEmpty:
break
# Put healthy sandboxes back
for pooled_sb in temp_sandboxes:
try:
self._sandbox_queue.put_nowait(pooled_sb)
except asyncio.QueueFull:
await self._terminate_sandbox(pooled_sb.sandbox)
self._stats["health_checks"] += 1
logger.debug(f"Health check completed. Pool size: {self._sandbox_queue.qsize()}")
async def _check_sandbox_health(self, sandbox: modal.Sandbox) -> bool:
"""Check if a sandbox is healthy."""
try:
# Run a simple Python command to check if the sandbox is responsive
proc = await asyncio.get_event_loop().run_in_executor(
None,
lambda: sandbox.exec("python", "-c", "print('health_check')", timeout=5)
)
output = proc.stdout.read()
return "health_check" in output
except Exception as e:
logger.debug(f"Sandbox health check failed: {e}")
return False
async def _cleanup_loop(self):
"""Background task to cleanup old sandboxes."""
while self._running:
try:
await asyncio.sleep(30) # Check every 30 seconds
await self._cleanup_old_sandboxes()
except Exception as e:
logger.error(f"Error in cleanup loop: {e}")
async def _cleanup_old_sandboxes(self):
"""Remove old sandboxes from the pool."""
now = time.time()
temp_sandboxes = []
while not self._sandbox_queue.empty():
try:
pooled_sb = self._sandbox_queue.get_nowait()
if now - pooled_sb.created_at < self.max_age_seconds:
temp_sandboxes.append(pooled_sb)
else:
# TERMINATE expired sandbox
await self._terminate_sandbox(pooled_sb.sandbox)
self._stats["recycled"] += 1
logger.debug("Cleaned up old sandbox")
except asyncio.QueueEmpty:
break
# Put non-expired sandboxes back
for pooled_sb in temp_sandboxes:
try:
self._sandbox_queue.put_nowait(pooled_sb)
except asyncio.QueueFull:
await self._terminate_sandbox(pooled_sb.sandbox)
async def _is_sandbox_alive(self, sandbox: modal.Sandbox) -> bool:
"""Check if a sandbox is alive by running a trivial command with better error handling."""
try:
# Check if sandbox was already marked as terminated
if hasattr(sandbox, '_terminated') and sandbox._terminated:
return False
# Use a shorter timeout for liveness checks
proc = await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(
None,
lambda: sandbox.exec("python", "-c", "print('ping')", timeout=3)
),
timeout=5.0 # Overall timeout
)
if hasattr(proc, "stdout") and hasattr(proc.stdout, "read"):
out = proc.stdout.read()
return "ping" in out
else:
# For some Modal versions, output might be returned directly
out = str(proc)
return "ping" in out
except asyncio.TimeoutError:
logger.debug("Liveness check timed out - sandbox likely dead")
return False
except Exception as e:
logger.debug(f"Liveness check failed: {e}")
            # Mark the sandbox as dead so later liveness checks can skip it
            sandbox._terminated = True
return False
async def _emergency_pool_reset(self):
"""Emergency reset of the pool when too many consecutive failures occur."""
logger.warning("Performing emergency pool reset due to consecutive failures")
# Drain and terminate all sandboxes in the pool
terminated_count = 0
while not self._sandbox_queue.empty():
try:
pooled_sb = self._sandbox_queue.get_nowait()
await self._terminate_sandbox(pooled_sb.sandbox)
terminated_count += 1
except asyncio.QueueEmpty:
break
logger.info(f"Emergency reset: terminated {terminated_count} sandboxes")
# Reset failure counter
self._consecutive_failures = 0
# Try to create one fresh sandbox to test if the underlying issue is resolved
try:
test_sandbox = await self._create_sandbox()
test_pooled = PooledSandbox(
sandbox=test_sandbox,
created_at=time.time(),
last_used=time.time(),
use_count=0
)
self._sandbox_queue.put_nowait(test_pooled)
logger.info("Emergency reset successful: created test sandbox")
except Exception as e:
logger.error(f"Emergency reset failed to create test sandbox: {e}")
            # The failure counter was already reset above, so future requests can retry creation
def get_stats(self) -> Dict[str, Any]:
"""Get pool statistics including health metrics."""
return {
**self._stats,
"pool_size": self._sandbox_queue.qsize(),
"target_pool_size": self.pool_size,
"running": self._running,
"consecutive_failures": self._consecutive_failures,
"last_successful_creation": self._last_successful_creation,
"time_since_last_success": time.time() - self._last_successful_creation,
"health_status": "healthy" if self._consecutive_failures < 3 else "degraded" if self._consecutive_failures < self._pool_reset_threshold else "critical"
}
    async def _proactive_cleanup(self) -> int:
"""Proactively clean up dead or unhealthy sandboxes from the pool."""
temp_sandboxes = []
cleaned_count = 0
# Drain the queue to check each sandbox
while not self._sandbox_queue.empty():
try:
pooled_sb = self._sandbox_queue.get_nowait()
# Quick health check
if await self._is_sandbox_alive(pooled_sb.sandbox):
# Sandbox is alive - keep it
temp_sandboxes.append(pooled_sb)
else:
# Sandbox is dead - terminate it
await self._terminate_sandbox(pooled_sb.sandbox)
cleaned_count += 1
logger.debug("Cleaned up dead sandbox during proactive cleanup")
except asyncio.QueueEmpty:
break
# Put healthy sandboxes back
for pooled_sb in temp_sandboxes:
try:
self._sandbox_queue.put_nowait(pooled_sb)
except asyncio.QueueFull:
# Shouldn't happen, but terminate if it does
await self._terminate_sandbox(pooled_sb.sandbox)
cleaned_count += 1
if cleaned_count > 0:
logger.info(f"Proactive cleanup removed {cleaned_count} dead sandboxes")
return cleaned_count
# Helper function for testing and debugging the sandbox pool
async def test_sandbox_pool_health(pool: WarmSandboxPool) -> Dict[str, Any]:
"""Test sandbox pool health and return detailed diagnostics."""
diagnostics: Dict[str, Any] = {
"timestamp": time.time(),
"pool_stats": pool.get_stats(),
"tests": {}
}
logger.info("Starting sandbox pool health test...")
# Test 1: Pool basic stats
stats = pool.get_stats()
diagnostics["tests"]["pool_stats"] = {
"passed": True,
"details": stats
}
# Test 2: Try to get a sandbox
try:
async with pool.get_sandbox(timeout=10.0) as sandbox:
# Test 3: Try to run a simple command
try:
proc = await asyncio.get_event_loop().run_in_executor(
None,
lambda: sandbox.exec("python", "-c", "print('health_test_ok')", timeout=5)
)
output = proc.stdout.read() if hasattr(proc.stdout, "read") else str(proc)
diagnostics["tests"]["sandbox_execution"] = {
"passed": "health_test_ok" in output,
"output": output[:200], # First 200 chars
"details": "Successfully executed test command"
}
except Exception as e:
diagnostics["tests"]["sandbox_execution"] = {
"passed": False,
"error": str(e),
"details": "Failed to execute test command in sandbox"
}
diagnostics["tests"]["sandbox_acquisition"] = {
"passed": True,
"details": "Successfully acquired and released sandbox"
}
except Exception as e:
diagnostics["tests"]["sandbox_acquisition"] = {
"passed": False,
"error": str(e),
"details": "Failed to acquire sandbox from pool"
}
diagnostics["tests"]["sandbox_execution"] = {
"passed": False,
"error": "Could not test - no sandbox available",
"details": "Skipped due to sandbox acquisition failure"
}
# Test 4: Check pool warmup status
    if pool._running:
        current_stats = pool.get_stats()  # Refresh after the acquisition test above
        warmup_needed = pool.pool_size - current_stats["pool_size"]
        diagnostics["tests"]["pool_warmup"] = {
            "passed": warmup_needed <= 1,  # Allow one sandbox to still be warming
            "details": f"Pool has {current_stats['pool_size']}/{pool.pool_size} sandboxes, {warmup_needed} needed"
}
else:
diagnostics["tests"]["pool_warmup"] = {
"passed": False,
"details": "Pool is not running"
}
# Overall health assessment
all_tests_passed = all(test.get("passed", False) for test in diagnostics["tests"].values())
diagnostics["overall_health"] = "healthy" if all_tests_passed else "unhealthy"
logger.info(f"Sandbox pool health test completed. Overall health: {diagnostics['overall_health']}")
return diagnostics
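# Minimal manual-run sketch (an assumption for local debugging, not part of the MCP hub wiring):
# it builds a throwaway Modal app and image, starts a small pool, runs the diagnostics above,
# and prints the result. The app name "sandbox-pool-demo" is hypothetical, and running this
# requires Modal credentials plus a Modal SDK that supports App.lookup(create_if_missing=True).
if __name__ == "__main__":
    async def _demo() -> None:
        app = modal.App.lookup("sandbox-pool-demo", create_if_missing=True)
        image = modal.Image.debian_slim()
        pool = WarmSandboxPool(app=app, image=image, pool_size=2)
        await pool.start()
        try:
            report = await test_sandbox_pool_health(pool)
            print(report["overall_health"])
            print(report["pool_stats"])
        finally:
            await pool.stop()

    asyncio.run(_demo())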