Spaces:
Running
Running
import asyncio
import logging
from contextlib import asynccontextmanager, closing
from threading import Lock
from typing import AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse

import pg8000
from pg8000 import Connection
from pg8000.exceptions import DatabaseError as Pg8000DatabaseError
from pydantic import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict
# Process-wide logging configuration (plain-text format) plus this module's logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class DatabaseSettings(BaseSettings):
    """Database configuration read from the environment and a ``.env`` file.

    Field names map to environment variables (``db_url`` -> ``DB_URL``,
    ``pool_size`` -> ``POOL_SIZE``); pydantic-settings matches them
    case-insensitively by default.
    """

    # pydantic v2 configuration style; the nested ``class Config`` is deprecated.
    model_config = SettingsConfigDict(env_file=".env")

    db_url: PostgresDsn
    pool_size: int = 5  # Default pool size is 5
# Exception hierarchy for the database layer; catch DatabaseError to handle all of them.
# NOTE: the public name ConnectionError is kept for callers, but inside this module
# it shadows the builtin ConnectionError.


class DatabaseError(Exception):
    """Root of every error raised by this module's database layer."""


class ConnectionError(DatabaseError):  # noqa: A001 - shadows the builtin name
    """A database connection could not be opened or failed while in use."""


class PoolExhaustedError(DatabaseError):
    """No idle connection was available in the pool."""


class QueryExecutionError(DatabaseError):
    """A SQL statement failed to execute."""


class HealthCheckError(DatabaseError):
    """The database health-check query failed or returned an unexpected result."""
class Database:
    """Fixed-size pool of pg8000 connections exposed through async methods.

    ``connect`` opens ``pool_size`` connections up front, ``get_connection``
    lends one out for the duration of an ``async with`` block, and
    ``disconnect`` closes whatever is currently in the pool. ``self.lock`` (a
    ``threading.Lock``) guards the pool list; it is only held around list
    operations, never across an ``await``.

    NOTE(review): the pg8000 calls below are synchronous (nothing is awaited),
    so each query runs on the event-loop thread and blocks it until the server
    responds. Consider offloading to a worker thread under concurrent load.
    NOTE(review): ``fetch`` and ``health_check`` neither commit nor roll back
    on success; if the connections are not in autocommit mode, a read leaves a
    transaction open until the next commit/rollback on that connection —
    confirm this is acceptable.
    """

    # Driver-level failures that can surface while opening, closing, or rolling
    # back a connection. Per PEP 249, InterfaceError is a sibling of
    # DatabaseError (not a subclass); pg8000 uses it e.g. for socket failures.
    _DRIVER_ERRORS = (Pg8000DatabaseError, pg8000.exceptions.InterfaceError, OSError)

    def __init__(self, db_url: PostgresDsn, pool_size: int):
        """Store configuration; no connection is opened until ``connect``.

        Args:
            db_url: PostgreSQL DSN providing user, password, host, port, database.
            pool_size: Number of connections ``connect`` opens.
        """
        self.db_url = db_url
        self.pool_size = pool_size
        self.pool: List[Connection] = []
        self.lock = Lock()

    async def connect(self) -> None:
        """Open ``pool_size`` connections and add them to the pool.

        If any connection fails to open, the ones already opened by this call
        are closed before raising, so no sockets are leaked.

        Raises:
            ConnectionError: If a connection cannot be established.
        """
        result = urlparse(str(self.db_url))
        opened: List[Connection] = []
        try:
            for _ in range(self.pool_size):
                opened.append(
                    pg8000.connect(
                        user=result.username,
                        password=result.password,
                        host=result.hostname,
                        port=result.port or 5432,
                        database=result.path.lstrip("/"),
                    )
                )
        except self._DRIVER_ERRORS as e:
            logger.error("Failed to create database connection pool: %s", e)
            for conn in opened:
                self._close_quietly(conn)
            raise ConnectionError("Failed to create database connection pool.") from e
        with self.lock:
            self.pool.extend(opened)
        logger.info(
            "Database connection pool created with %d connections.", self.pool_size
        )

    async def disconnect(self) -> None:
        """Close every connection currently in the pool and empty the pool.

        A failure to close one connection is logged and does not stop the
        others from being closed. Connections that are checked out at the time
        of the call are not closed here; they rejoin the pool when released.
        """
        with self.lock:
            for conn in self.pool:
                self._close_quietly(conn)
            self.pool.clear()
        logger.info("Database connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator[Connection, None]:
        """Borrow a pooled connection for the duration of an ``async with``.

        Usage::

            async with db.get_connection() as conn:
                ...

        The connection always goes back to the pool on exit. If the body
        raises, the connection is rolled back first so the next borrower does
        not inherit a failed transaction.

        Raises:
            PoolExhaustedError: If no connection is currently available.
            ConnectionError: If the body raises a pg8000 ``DatabaseError``.
        """
        with self.lock:
            if not self.pool:
                logger.error("Connection pool is exhausted.")
                raise PoolExhaustedError("No available connections in the pool.")
            conn = self.pool.pop()
        try:
            yield conn
        except Pg8000DatabaseError as e:
            logger.error("Connection error: %s", e)
            self._rollback_quietly(conn)
            raise ConnectionError("Failed to use database connection.") from e
        except Exception:
            # Any other failure in the body: reset the transaction, re-raise as-is.
            self._rollback_quietly(conn)
            raise
        finally:
            with self.lock:
                self.pool.append(conn)

    async def fetch(self, query: str, *args) -> Dict[str, List]:
        """
        Execute a SELECT query and return the results as a dictionary of lists.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters, passed to ``cursor.execute`` as one tuple.

        Returns:
            Dict[str, List]: Column name -> list of that column's values in row
            order. Columns with duplicate names share one list.

        Raises:
            QueryExecutionError: If the query execution fails.
            PoolExhaustedError: If no pooled connection is available.
        """
        async with self.get_connection() as conn:
            # Translate driver errors inside the block so they surface as
            # QueryExecutionError instead of get_connection's ConnectionError.
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, args)
                    rows = cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]
            except Pg8000DatabaseError as e:
                logger.error("Query execution failed: %s", e)
                raise QueryExecutionError(f"Failed to execute query: {query}") from e

        # Pivot the row sequences into one list per column.
        data_dict: Dict[str, List] = {column: [] for column in columns}
        for row in rows:
            for column, value in zip(columns, row):
                data_dict[column].append(value)
        return data_dict

    async def execute(self, query: str, *args) -> None:
        """
        Execute an INSERT, UPDATE, or DELETE query and commit it.

        Args:
            query (str): The SQL query to execute.
            *args: Query parameters, passed to ``cursor.execute`` as one tuple.

        Raises:
            QueryExecutionError: If the query execution or commit fails (the
                connection is rolled back before it returns to the pool).
            PoolExhaustedError: If no pooled connection is available.
        """
        async with self.get_connection() as conn:
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(query, args)
                conn.commit()
            except Pg8000DatabaseError as e:
                logger.error("Query execution failed: %s", e)
                raise QueryExecutionError(f"Failed to execute query: {query}") from e

    async def health_check(self) -> bool:
        """
        Perform a health check by executing ``SELECT 1``.

        Returns:
            bool: Always True; every failure is reported by raising instead.

        Raises:
            HealthCheckError: If the query fails or does not return 1.
            PoolExhaustedError: If no pooled connection is available.
        """
        async with self.get_connection() as conn:
            try:
                with closing(conn.cursor()) as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
            except Pg8000DatabaseError as e:
                logger.error("Health check failed: %s", e)
                raise HealthCheckError("Failed to perform health check.") from e

        if result and result[0] == 1:
            logger.info("Database health check succeeded.")
            return True
        logger.error("Database health check failed: Unexpected result.")
        raise HealthCheckError("Unexpected result from health check query.")

    @staticmethod
    def _close_quietly(conn: Connection) -> None:
        """Close *conn*, logging (not raising) any driver or socket error."""
        try:
            conn.close()
        except Database._DRIVER_ERRORS as e:
            logger.warning("Error while closing database connection: %s", e)

    @staticmethod
    def _rollback_quietly(conn: Connection) -> None:
        """Roll back *conn*'s transaction, logging (not raising) any failure."""
        try:
            conn.rollback()
        except Database._DRIVER_ERRORS as e:
            logger.warning("Rollback failed while releasing connection: %s", e)
# Dependency that provides a connected Database and tears it down afterwards.
async def get_db() -> AsyncGenerator[Database, None]:
    """Yield a freshly connected ``Database``; close its pool when finished.

    Every call builds its own ``Database`` from ``DatabaseSettings`` and opens
    ``pool_size`` new connections, which are closed once the generator is
    finalized. NOTE(review): if this is used as a per-request dependency, that
    means one new pool per request — consider sharing a single instance.
    """
    config = DatabaseSettings()
    database = Database(db_url=config.db_url, pool_size=config.pool_size)
    await database.connect()
    try:
        yield database
    finally:
        await database.disconnect()
# Example usage: run a one-off health check against the configured database.
if __name__ == "__main__":

    async def main() -> None:
        """Connect, run the health check, report the outcome, and always disconnect."""
        settings = DatabaseSettings()
        db = Database(db_url=settings.db_url, pool_size=settings.pool_size)
        try:
            # connect() is inside the try so the finally-disconnect also runs when
            # connecting fails part-way through.
            await db.connect()
            is_healthy = await db.health_check()
            print(f"Database health check: {'Success' if is_healthy else 'Failure'}")
        except HealthCheckError as e:
            print(f"Health check failed: {e}")
        except DatabaseError as e:
            # ConnectionError / PoolExhaustedError: report instead of a traceback.
            print(f"Database error: {e}")
        finally:
            await db.disconnect()

    asyncio.run(main())