Spaces:
Build error
Build error
import sqlite3 | |
import weakref | |
from abc import ABC, abstractmethod | |
from typing import Any, Set | |
import threading | |
from overrides import override | |
from typing_extensions import Annotated | |
class Connection:
    """A pooled sqlite connection that hands itself back to its pool on close().

    Wraps a ``sqlite3.Connection`` opened with ``check_same_thread=False`` (so
    the owning pool may let threads other than the creator use it) and with
    ``isolation_level = None`` so transactions are handled explicitly by the
    caller via ``commit()`` / ``rollback()``.
    """

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        """Open a new sqlite connection owned by ``pool``.

        Args:
            pool: The pool this connection is handed back to by ``close()``.
            db_file: Path (or URI, when ``is_uri`` is true) of the database.
            is_uri: Forwarded as ``uri=`` to ``sqlite3.connect``.
            *args, **kwargs: Forwarded to ``sqlite3.connect``.
        """
        self._pool = pool
        self._db_file = db_file
        # timeout is in seconds: wait up to 1000 s on a locked database.
        self._conn = sqlite3.connect(
            db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs
        )  # type: ignore
        self._conn.isolation_level = None  # Handle commits explicitly

    def execute(self, sql: str, parameters=...) -> sqlite3.Cursor:  # type: ignore
        """Execute ``sql`` and return the resulting cursor.

        ``...`` (Ellipsis) is the "no parameters given" sentinel; only when
        parameters are supplied are they bound to the statement.
        """
        if parameters is ...:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction on the underlying connection."""
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        """Return a new cursor on the underlying connection."""
        return self._conn.cursor()

    def close(self) -> None:
        """Return this connection to its pool.

        This does NOT close the sqlite handle; the connection stays usable
        by the pool. Use ``close_actual()`` to really close it.
        """
        self._pool.return_to_pool(self)

    def close_actual(self) -> None:
        """Actually closes the connection to the db"""
        self._conn.close()
class Pool(ABC):
    """Abstract base class for a pool of connections to a sqlite database.

    Every method is abstract: concrete pools must implement all of them, and
    ``Pool`` itself cannot be instantiated.
    """

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        """Create a pool for ``db_file`` (interpreted as a URI when ``is_uri``)."""

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> "Connection":
        """Return a connection from the pool."""

    @abstractmethod
    def close(self) -> None:
        """Close all connections in the pool."""

    @abstractmethod
    def return_to_pool(self, conn: "Connection") -> None:
        """Return a connection to the pool."""
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.

    This is used because sqlite does not support multithreaded access with connection timeouts when using the
    shared cache mode. We use the shared cache mode to allow multiple threads to share a database.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.RLock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        """Create an empty pool for ``db_file`` (a URI when ``is_uri`` is true)."""
        # Weak references to every connection opened, so close() can reach
        # them without keeping them alive.
        self._connections = set()
        # Per-thread storage; the thread's connection lives in ``.conn``.
        self._connection = threading.local()
        # Reentrant: the same thread may connect() again before returning.
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Acquire the pool lock and return the calling thread's connection.

        The lock stays held until ``return_to_pool`` is called, so only one
        thread uses the database at a time. Each ``connect`` must be matched
        by one ``return_to_pool``. ``*args``/``**kwargs`` are forwarded to
        ``Connection`` only when a new connection has to be opened; if that
        fails, the lock is released before the error propagates.
        """
        self._lock.acquire()
        existing = getattr(self._connection, "conn", None)
        if existing is not None:
            return existing  # type: ignore # cast doesn't work here for some reason
        try:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
        except BaseException:
            # Without this, a failed sqlite3.connect would leave the lock held
            # forever and every other thread would block in connect().
            self._lock.release()
            raise
        self._connection.conn = new_connection
        self._connections.add(weakref.ref(new_connection))
        return new_connection

    def return_to_pool(self, conn: Connection) -> None:
        """Release one hold on the pool lock taken by ``connect``.

        Releasing a lock the calling thread does not own raises RuntimeError;
        that is ignored so an unmatched return is harmless.
        """
        try:
            self._lock.release()
        except RuntimeError:
            pass

    def close(self) -> None:
        """Close every connection opened by this pool and forget them all.

        Also releases one hold on the lock if the calling thread owns it.
        """
        for ref in self._connections:
            # Dereference once: the connection may be garbage-collected
            # between two separate ref() calls.
            conn = ref()
            if conn is not None:
                conn.close_actual()
        self._connections.clear()
        self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            # The calling thread did not hold the lock; nothing to release.
            pass
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
    extended to do so and block on connect() if the cap is reached.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.Lock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        """Create an empty pool for ``db_file`` (a URI when ``is_uri`` is true)."""
        # Weak references to every connection opened, so close() can reach
        # them without keeping them alive.
        self._connections = set()
        # Per-thread storage; the thread's connection lives in ``.conn``.
        self._connection = threading.local()
        # Guards ``_connections``.
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return the calling thread's connection, opening it on first use.

        ``*args``/``**kwargs`` are forwarded to ``Connection`` only when a new
        connection has to be opened.
        """
        existing = getattr(self._connection, "conn", None)
        if existing is not None:
            return existing  # type: ignore # cast doesn't work here for some reason
        new_connection = Connection(
            self, self._db_file, self._is_uri, *args, **kwargs
        )
        self._connection.conn = new_connection
        with self._lock:
            self._connections.add(weakref.ref(new_connection))
        return new_connection

    def close(self) -> None:
        """Close every connection opened by this pool and forget them all."""
        with self._lock:
            for ref in self._connections:
                # Dereference once: the connection may be garbage-collected
                # between two separate ref() calls.
                conn = ref()
                if conn is not None:
                    conn.close_actual()
            self._connections.clear()
            self._connection = threading.local()

    def return_to_pool(self, conn: Connection) -> None:
        """No-op: each thread keeps its own connection, so nothing is shared back."""
        pass