import os
import sys
from pathlib import Path
import asyncio
import logging
import json  # Used for serializing/deserializing JSON columns
from typing import Optional, Dict
from contextlib import asynccontextmanager

import aiosqlite
import aiofiles
import xxhash

from .utils import ensure_content_dirs, generate_content_hash
from .utils import get_error_context, create_box_message
from .models import CrawlResult, MarkdownGenerationResult
from .config import NEED_MIGRATION
from .version_manager import VersionManager
from .async_logger import AsyncLogger

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolve the base directory and database path, creating the directory if needed
base_directory = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
os.makedirs(base_directory, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db")

class AsyncDatabaseManager:
    def __init__(self, pool_size: int = 10, max_retries: int = 3):
        self.db_path = DB_PATH
        self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
        self.pool_size = pool_size
        self.max_retries = max_retries
        self.connection_pool: Dict[int, aiosqlite.Connection] = {}
        self.pool_lock = asyncio.Lock()
        self.init_lock = asyncio.Lock()
        self.connection_semaphore = asyncio.Semaphore(pool_size)
        self._initialized = False
        self.version_manager = VersionManager()
        self.logger = AsyncLogger(
            # base_directory already ends in ".crawl4ai", so the log file sits alongside the database
            log_file=os.path.join(base_directory, "crawler_db.log"),
            verbose=False,
            tag_width=10
        )

    async def initialize(self):
        """Initialize the database and connection pool"""
        try:
            self.logger.info("Initializing database", tag="INIT")
            # Ensure the directory for the database file exists
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

            # Check if a version update is needed
            needs_update = self.version_manager.needs_update()

            # Always ensure the base table exists
            await self.ainit_db()

            # Verify the table exists
            async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
                async with db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='crawled_data'"
                ) as cursor:
                    result = await cursor.fetchone()
                    if not result:
                        raise Exception("crawled_data table was not created")

            # If the version changed or this is a fresh install, run updates
            if needs_update:
                self.logger.info("New version detected, running updates", tag="INIT")
                await self.update_db_schema()
                from .migrations import run_migration  # Imported here to avoid circular imports
                await run_migration()
                self.version_manager.update_version()  # Update stored version after successful migration
                self.logger.success("Version update completed successfully", tag="COMPLETE")
            else:
                self.logger.success("Database initialization completed successfully", tag="COMPLETE")

        except Exception as e:
            self.logger.error(
                message="Database initialization error: {error}",
                tag="ERROR",
                params={"error": str(e)}
            )
            self.logger.info(
                message="Database will be initialized on first use",
                tag="INIT"
            )
            raise

    async def cleanup(self):
        """Close all pooled connections when shutting down"""
        async with self.pool_lock:
            for conn in self.connection_pool.values():
                await conn.close()
            self.connection_pool.clear()

    @asynccontextmanager
    async def get_connection(self):
        """Connection pool manager with enhanced error handling"""
        if not self._initialized:
            async with self.init_lock:
                if not self._initialized:
                    try:
                        await self.initialize()
                        self._initialized = True
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        self.logger.error(
                            message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
                            tag="ERROR",
                            force_verbose=True,
                            params={
                                "error": str(e),
                                "context": error_context["code_context"],
                                "traceback": error_context["full_traceback"]
                            }
                        )
                        raise

        await self.connection_semaphore.acquire()
        task_id = id(asyncio.current_task())

        try:
            async with self.pool_lock:
                if task_id not in self.connection_pool:
                    try:
                        conn = await aiosqlite.connect(
                            self.db_path,
                            timeout=30.0
                        )
                        await conn.execute('PRAGMA journal_mode = WAL')
                        await conn.execute('PRAGMA busy_timeout = 5000')

                        # Verify the database structure before handing out the connection
                        async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
                            columns = await cursor.fetchall()
                            column_names = [col[1] for col in columns]
                            expected_columns = {
                                'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
                                'success', 'media', 'links', 'metadata', 'screenshot',
                                'response_headers', 'downloaded_files'
                            }
                            missing_columns = expected_columns - set(column_names)
                            if missing_columns:
                                raise ValueError(f"Database missing columns: {missing_columns}")

                        self.connection_pool[task_id] = conn
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        error_message = (
                            f"Unexpected error in db get_connection at line {error_context['line_no']} "
                            f"in {error_context['function']} ({error_context['filename']}):\n"
                            f"Error: {str(e)}\n\n"
                            f"Code context:\n{error_context['code_context']}"
                        )
                        self.logger.error(
                            message=create_box_message(error_message, type="error"),
                        )
                        raise

            yield self.connection_pool[task_id]

        except Exception as e:
            error_context = get_error_context(sys.exc_info())
            error_message = (
                f"Unexpected error in db get_connection at line {error_context['line_no']} "
                f"in {error_context['function']} ({error_context['filename']}):\n"
                f"Error: {str(e)}\n\n"
                f"Code context:\n{error_context['code_context']}"
            )
            self.logger.error(
                message=create_box_message(error_message, type="error"),
            )
            raise
        finally:
            async with self.pool_lock:
                if task_id in self.connection_pool:
                    await self.connection_pool[task_id].close()
                    del self.connection_pool[task_id]
            self.connection_semaphore.release()

    async def execute_with_retry(self, operation, *args):
        """Execute database operations with retry logic"""
        for attempt in range(self.max_retries):
            try:
                async with self.get_connection() as db:
                    result = await operation(db, *args)
                    await db.commit()
                    return result
            except Exception as e:
                if attempt == self.max_retries - 1:
                    self.logger.error(
                        message="Operation failed after {retries} attempts: {error}",
                        tag="ERROR",
                        force_verbose=True,
                        params={
                            "retries": self.max_retries,
                            "error": str(e)
                        }
                    )
                    raise
                await asyncio.sleep(1 * (attempt + 1))  # Linear backoff: wait 1s, then 2s, then 3s...

    async def ainit_db(self):
        """Initialize the database schema"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS crawled_data (
                    url TEXT PRIMARY KEY,
                    html TEXT,
                    cleaned_html TEXT,
                    markdown TEXT,
                    extracted_content TEXT,
                    success BOOLEAN,
                    media TEXT DEFAULT "{}",
                    links TEXT DEFAULT "{}",
                    metadata TEXT DEFAULT "{}",
                    screenshot TEXT DEFAULT "",
                    response_headers TEXT DEFAULT "{}",
                    downloaded_files TEXT DEFAULT "{}" -- New column added
                )
            ''')
            await db.commit()

    async def update_db_schema(self):
        """Update the database schema if needed"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            cursor = await db.execute("PRAGMA table_info(crawled_data)")
            columns = await cursor.fetchall()
            column_names = [column[1] for column in columns]

            # Columns introduced after the original schema; add any that are missing
            new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
            for column in new_columns:
                if column not in column_names:
                    await self.aalter_db_add_column(column, db)
            await db.commit()

    async def aalter_db_add_column(self, new_column: str, db):
        """Add a new column to the crawled_data table"""
        if new_column == 'response_headers':
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
        else:
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
        self.logger.info(
            message="Added column '{column}' to the database",
            tag="INIT",
            params={"column": new_column}
        )

    async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
        """Retrieve cached URL data as a CrawlResult"""
        async def _get(db):
            async with db.execute(
                'SELECT * FROM crawled_data WHERE url = ?', (url,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None

                # Build a dict from the row using the cursor's column names
                columns = [description[0] for description in cursor.description]
                row_dict = dict(zip(columns, row))

                # Load content from the filesystem using the stored hashes;
                # the content types mirror those used when storing in acache_url
                content_fields = {
                    'html': (row_dict['html'], 'html'),
                    'cleaned_html': (row_dict['cleaned_html'], 'cleaned'),
                    'markdown': (row_dict['markdown'], 'markdown'),
                    'extracted_content': (row_dict['extracted_content'], 'extracted'),
                    'screenshot': (row_dict['screenshot'], 'screenshots'),
                }
                for field, (hash_value, content_type) in content_fields.items():
                    if hash_value:
                        content = await self._load_content(hash_value, content_type)
                        row_dict[field] = content or ""
                    else:
                        row_dict[field] = ""

                # Parse JSON fields
                json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
                for field in json_fields:
                    try:
                        row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
                    except json.JSONDecodeError:
                        row_dict[field] = {}

                # A dict here is a serialized MarkdownGenerationResult: expose it as
                # markdown_v2 and keep the plain raw_markdown string under markdown
                if isinstance(row_dict['markdown'], dict):
                    row_dict['markdown_v2'] = row_dict['markdown']
                    if row_dict['markdown'].get('raw_markdown'):
                        row_dict['markdown'] = row_dict['markdown']['raw_markdown']

                # Parse downloaded_files
                try:
                    row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
                except json.JSONDecodeError:
                    row_dict['downloaded_files'] = []

                # Drop any fields not present on the CrawlResult model
                valid_fields = CrawlResult.__annotations__.keys()
                filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
                return CrawlResult(**filtered_dict)

        try:
            return await self.execute_with_retry(_get)
        except Exception as e:
            self.logger.error(
                message="Error retrieving cached URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return None

    async def acache_url(self, result: CrawlResult):
        """Cache a CrawlResult: write content to disk and store the hashes in SQLite"""
        # Map each field to (content, content type) for filesystem storage
        content_map = {
            'html': (result.html, 'html'),
            'cleaned_html': (result.cleaned_html or "", 'cleaned'),
            'markdown': None,  # Filled in below once the markdown form is known
            'extracted_content': (result.extracted_content or "", 'extracted'),
            'screenshot': (result.screenshot or "", 'screenshots')
        }

        try:
            if isinstance(result.markdown, MarkdownGenerationResult):
                content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
            elif getattr(result, 'markdown_v2', None) is not None:
                content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
            elif isinstance(result.markdown, str):
                markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
                content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
            else:
                content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
        except Exception as e:
            self.logger.warning(
                message=f"Error processing markdown content: {str(e)}",
                tag="WARNING"
            )
            # Fall back to an empty markdown result
            content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')

        content_hashes = {}
        for field, (content, content_type) in content_map.items():
            content_hashes[field] = await self._store_content(content, content_type)

        async def _cache(db):
            await db.execute('''
                INSERT INTO crawled_data (
                    url, html, cleaned_html, markdown,
                    extracted_content, success, media, links, metadata,
                    screenshot, response_headers, downloaded_files
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(url) DO UPDATE SET
                    html = excluded.html,
                    cleaned_html = excluded.cleaned_html,
                    markdown = excluded.markdown,
                    extracted_content = excluded.extracted_content,
                    success = excluded.success,
                    media = excluded.media,
                    links = excluded.links,
                    metadata = excluded.metadata,
                    screenshot = excluded.screenshot,
                    response_headers = excluded.response_headers,
                    downloaded_files = excluded.downloaded_files
            ''', (
                result.url,
                content_hashes['html'],
                content_hashes['cleaned_html'],
                content_hashes['markdown'],
                content_hashes['extracted_content'],
                result.success,
                json.dumps(result.media),
                json.dumps(result.links),
                json.dumps(result.metadata or {}),
                content_hashes['screenshot'],
                json.dumps(result.response_headers or {}),
                json.dumps(result.downloaded_files or [])
            ))

        try:
            await self.execute_with_retry(_cache)
        except Exception as e:
            self.logger.error(
                message="Error caching URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )

    async def aget_total_count(self) -> int:
        """Get the total number of cached URLs"""
        async def _count(db):
            async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
                result = await cursor.fetchone()
                return result[0] if result else 0

        try:
            return await self.execute_with_retry(_count)
        except Exception as e:
            self.logger.error(
                message="Error getting total count: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return 0

    async def aclear_db(self):
        """Clear all rows from the crawled_data table"""
        async def _clear(db):
            await db.execute('DELETE FROM crawled_data')

        try:
            await self.execute_with_retry(_clear)
        except Exception as e:
            self.logger.error(
                message="Error clearing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )

    async def aflush_db(self):
        """Drop the entire crawled_data table"""
        async def _flush(db):
            await db.execute('DROP TABLE IF EXISTS crawled_data')

        try:
            await self.execute_with_retry(_flush)
        except Exception as e:
            self.logger.error(
                message="Error flushing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )

    async def _store_content(self, content: str, content_type: str) -> str:
        """Store content on the filesystem and return its hash"""
        if not content:
            return ""

        content_hash = generate_content_hash(content)
        file_path = os.path.join(self.content_paths[content_type], content_hash)

        # Content is addressed by hash, so only write if the file doesn't already exist
        if not os.path.exists(file_path):
            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
                await f.write(content)

        return content_hash

    async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
        """Load content from the filesystem by hash"""
        if not content_hash:
            return None

        file_path = os.path.join(self.content_paths[content_type], content_hash)
        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception:
            self.logger.error(
                message="Failed to load content: {file_path}",
                tag="ERROR",
                force_verbose=True,
                params={"file_path": file_path}
            )
            return None

# Create a singleton instance
async_db_manager = AsyncDatabaseManager()
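
# Illustrative usage sketch (not part of the library): a minimal example of how the
# singleton manager above could be exercised end to end. It assumes CrawlResult can be
# constructed with just url, html, success, and a plain-string markdown, which may not
# match the real model's required fields; adjust to whatever crawl4ai.models.CrawlResult
# actually expects before running.
if __name__ == "__main__":
    async def _demo():
        # Initialize the schema and connection pool explicitly (this would otherwise
        # happen lazily on first use inside get_connection)
        await async_db_manager.initialize()

        # Cache a hypothetical result, read it back, and report the cache size
        sample = CrawlResult(
            url="https://example.com",
            html="<html><body>Example</body></html>",
            success=True,
            markdown="# Example page",  # exercises the plain-string markdown path in acache_url
        )
        await async_db_manager.acache_url(sample)

        cached = await async_db_manager.aget_cached_url("https://example.com")
        total = await async_db_manager.aget_total_count()
        print(f"cache hit: {cached is not None}, total cached URLs: {total}")

        await async_db_manager.cleanup()

    asyncio.run(_demo())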