Spaces:
Build error
Build error
File size: 5,372 Bytes
60e3a80 |
|
import sqlite3
import weakref
from abc import ABC, abstractmethod
from typing import Any, Set
import threading
from overrides import override
from typing_extensions import Annotated
class Connection:
    """A thin wrapper around a ``sqlite3.Connection`` created on behalf of a pool.

    The wrapper remembers the pool that created it and the database file it
    was opened on, and forwards ``execute``/``commit``/``rollback``/``cursor``
    to the underlying sqlite3 connection. This class defines no ``close()``
    method; the underlying connection is only closed by ``close_actual()``.
    """

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        """Open the underlying sqlite3 connection.

        Args:
            pool: The pool this connection belongs to (stored, not used here).
            db_file: Database path, or a URI when ``is_uri`` is true.
            is_uri: Forwarded to ``sqlite3.connect`` as ``uri``.
            *args: Extra positional arguments forwarded to ``sqlite3.connect``
                after ``db_file``.
            **kwargs: Extra keyword arguments forwarded to ``sqlite3.connect``.
                ``timeout`` and ``check_same_thread`` default to ``1000`` and
                ``False`` but may be overridden here.
        """
        self._pool = pool
        self._db_file = db_file
        # Treat the pool's preferred settings as defaults rather than fixed
        # keywords, so a caller-supplied ``timeout``/``check_same_thread`` no
        # longer fails with "got multiple values for keyword argument".
        connect_kwargs = {"timeout": 1000, "check_same_thread": False, **kwargs}
        self._conn = sqlite3.connect(
            db_file, *args, uri=is_uri, **connect_kwargs
        )  # type: ignore
        self._conn.isolation_level = None  # Handle commits explicitly

    def execute(self, sql: str, parameters=...) -> sqlite3.Cursor:  # type: ignore
        """Execute ``sql`` on the underlying connection and return the cursor.

        ``parameters`` uses ``...`` as its "not supplied" sentinel; when it is
        left at that default the statement is executed without parameters.
        """
        if parameters is ...:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    def commit(self) -> None:
        """Commit on the underlying connection."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back on the underlying connection."""
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        """Return a new cursor from the underlying connection."""
        return self._conn.cursor()

    def close_actual(self) -> None:
        """Actually closes the connection to the db"""
        self._conn.close()
class Pool(ABC):
    """Interface that every sqlite connection pool in this module implements."""

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        """Set up a pool for ``db_file``; ``is_uri`` marks it as a URI."""
        ...

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Hand out a connection that belongs to this pool."""
        ...

    @abstractmethod
    def close(self) -> None:
        """Shut down every connection the pool holds."""
        ...

    @abstractmethod
    def return_to_pool(self, conn: Connection) -> None:
        """Give ``conn`` back to the pool."""
        ...
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time.

    This is used because sqlite does not support multithreaded access with connection timeouts when using the
    shared cache mode. We use the shared cache mode to allow multiple threads to share a database.

    Locking protocol: ``connect()`` acquires the pool's reentrant lock and
    returns while still holding it; ``return_to_pool()`` releases it.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.RLock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        """Create an empty pool for ``db_file``; no connection is opened yet."""
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Acquire the pool lock and return the calling thread's connection.

        A connection is created on the thread's first call (forwarding
        ``*args``/``**kwargs`` to ``Connection``) and reused afterwards. The
        lock stays held on return and is released by ``return_to_pool()``.
        """
        self._lock.acquire()
        existing = getattr(self._connection, "conn", None)
        if existing is not None:
            return existing  # type: ignore # cast doesn't work here for some reason
        try:
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
        except BaseException:
            # No connection is handed out on failure, so the caller has
            # nothing to return to the pool; release here or the lock leaks.
            self._lock.release()
            raise
        self._connection.conn = new_connection
        self._connections.add(weakref.ref(new_connection))
        return new_connection

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """Release the lock taken by ``connect()``; ``conn`` itself is unused."""
        try:
            self._lock.release()
        except RuntimeError:
            # RLock.release() raises RuntimeError when the calling thread
            # does not own the lock; treated as a no-op.
            pass

    @override
    def close(self) -> None:
        """Close every still-alive connection and reset the pool's state."""
        for ref in self._connections:
            # Dereference once: calling the weakref twice could return None
            # the second time if the connection is collected in between.
            live = ref()
            if live is not None:
                live.close_actual()
        self._connections.clear()
        self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            # The calling thread may not be holding the lock.
            pass
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be
    extended to do so and block on connect() if the cap is reached.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.Lock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        """Create an empty pool for ``db_file``; no connection is opened yet."""
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return the calling thread's connection, creating it on first use.

        ``*args``/``**kwargs`` are forwarded to ``Connection`` only when a new
        connection is created.
        """
        existing = getattr(self._connection, "conn", None)
        if existing is not None:
            return existing  # type: ignore # cast doesn't work here for some reason
        new_connection = Connection(
            self, self._db_file, self._is_uri, *args, **kwargs
        )
        self._connection.conn = new_connection
        # The set of weak references is shared by all threads, so guard it.
        with self._lock:
            self._connections.add(weakref.ref(new_connection))
        return new_connection

    @override
    def close(self) -> None:
        """Close every still-alive connection and reset the pool's state."""
        with self._lock:
            for ref in self._connections:
                # Dereference once: calling the weakref twice could return
                # None the second time if the connection is collected in
                # between.
                live = ref()
                if live is not None:
                    live.close_actual()
            self._connections.clear()
            self._connection = threading.local()

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """No-op; ``conn`` is ignored."""
        pass  # Each thread gets its own connection, so we don't need to return it to the pool
|