|
import sqlite3 |
|
import weakref |
|
from abc import ABC, abstractmethod |
|
from typing import Any, Set |
|
import threading |
|
from overrides import override |
|
from typing_extensions import Annotated |
|
|
|
|
|
class Connection:
    """A pooled sqlite3 connection.

    ``close()`` hands the connection back to the pool that created it (the
    underlying sqlite3 connection stays open and reusable); ``close_actual()``
    really closes the underlying sqlite3 connection.
    """

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        """Open a new sqlite3 connection to ``db_file`` owned by ``pool``.

        Args:
            pool: The pool this connection is handed back to by ``close()``.
            db_file: Database path, or a URI when ``is_uri`` is true.
            is_uri: Forwarded to ``sqlite3.connect`` as ``uri``.
            *args: Extra positional arguments forwarded to ``sqlite3.connect``
                right after ``db_file``.
            **kwargs: Extra keyword arguments for ``sqlite3.connect``. They take
                precedence over the defaults chosen here (``timeout``,
                ``check_same_thread``, ``uri``).
        """
        self._pool = pool
        self._db_file = db_file
        # Build defaults first and let caller kwargs win: passing the defaults
        # as literal keywords next to **kwargs turns any overlap into TypeError.
        options: dict[str, Any] = {
            # Seconds to wait on a locked database before sqlite3 raises
            # OperationalError.
            "timeout": 1000,
            # Pools may use or close this connection from a thread other than
            # the one that created it, so disable sqlite3's same-thread check.
            "check_same_thread": False,
            "uri": is_uri,
        }
        options.update(kwargs)
        self._conn = sqlite3.connect(db_file, *args, **options)
        # isolation_level=None: sqlite3 opens no transactions implicitly;
        # transactions are controlled explicitly by the caller.
        self._conn.isolation_level = None

    def execute(self, sql: str, parameters: Any = ...) -> sqlite3.Cursor:
        """Execute ``sql`` and return the resulting cursor.

        The default ``...`` means "no parameters supplied"; any other value
        (including ``None``) is passed through to ``sqlite3.Connection.execute``.
        """
        if parameters is ...:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    def commit(self) -> None:
        """Commit on the underlying sqlite3 connection."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back on the underlying sqlite3 connection."""
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        """Return a new cursor on the underlying sqlite3 connection."""
        return self._conn.cursor()

    def close(self) -> None:
        """Return this connection to its pool without closing it.

        Delegates to ``pool.return_to_pool(self)``; call ``close_actual()`` to
        really close the sqlite3 connection.
        """
        self._pool.return_to_pool(self)

    def close_actual(self) -> None:
        """Actually close the underlying connection to the db."""
        self._conn.close()
|
|
|
|
|
class Pool(ABC):
    """Interface for a pool that hands out connections to one sqlite database.

    Every method is abstract: concrete pools decide how connections are
    created, shared between threads, handed back, and torn down.
    """

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        """Configure the pool for ``db_file`` (treated as a URI if ``is_uri``)."""

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Hand out a connection managed by this pool."""

    @abstractmethod
    def close(self) -> None:
        """Shut down every connection this pool manages."""

    @abstractmethod
    def return_to_pool(self, conn: Connection) -> None:
        """Give ``conn`` back to the pool once the caller is done with it."""
|
|
|
|
|
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure
    that only one thread can use it at a time.

    This is used because sqlite does not support multithreaded access with
    connection timeouts when using the shared cache mode. We use the shared
    cache mode to allow multiple threads to share a database.

    Each successful ``connect()`` acquires a re-entrant lock once, and each
    ``return_to_pool()`` releases it once, so callers must pair them.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.RLock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        # Weak references to every connection created, so close() can reach
        # them without keeping them alive.
        self._connections = set()
        # Holds this thread's Connection as attribute ``conn``.
        self._connection = threading.local()
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Acquire the pool lock and return the calling thread's connection.

        A new connection is created on this thread's first call (``args`` and
        ``kwargs`` are only used then). The lock stays held until
        ``return_to_pool()`` is called. If creating the connection fails, the
        lock is released before the exception propagates.
        """
        self._lock.acquire()
        try:
            existing = getattr(self._connection, "conn", None)
            if existing is not None:
                return existing
            new_connection = Connection(
                self, self._db_file, self._is_uri, *args, **kwargs
            )
            self._connection.conn = new_connection
            self._connections.add(weakref.ref(new_connection))
            return new_connection
        except BaseException:
            # The caller never receives a connection, so it will never call
            # return_to_pool(); drop our acquisition here to avoid deadlocking
            # every other thread.
            self._lock.release()
            raise

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """Release one level of the pool lock held by the calling thread."""
        try:
            self._lock.release()
        except RuntimeError:
            # Deliberate best-effort: RLock.release() raises RuntimeError when
            # the calling thread does not hold the lock.
            pass

    @override
    def close(self) -> None:
        """Close every still-alive connection and reset per-thread state.

        Also releases one level of the lock if the calling thread holds it.
        """
        # Snapshot the set: connect() on another thread may add to it while
        # we iterate.
        for ref in list(self._connections):
            # Dereference once; a second ref() call could return None if the
            # connection was collected in between.
            conn = ref()
            if conn is not None:
                conn.close_actual()
        self._connections.clear()
        # Fresh thread-local so the next connect() on any thread opens a new
        # connection instead of reusing a closed one.
        self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            # The calling thread did not hold the lock; nothing to release.
            pass
|
|
|
|
|
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap
    on the number of connections, but it could be extended to do so and block
    on connect() if the cap is reached.
    """

    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    _lock: threading.Lock
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        # Weak references to every connection created, so close() can reach
        # them without keeping them alive.
        self._connections = set()
        # Holds this thread's Connection as attribute ``conn``.
        self._connection = threading.local()
        # Guards _connections (and the thread-local reset in close()).
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return the calling thread's connection, creating it on first use.

        ``args`` and ``kwargs`` are forwarded to ``Connection`` only when a new
        connection is created; later calls on the same thread ignore them.
        """
        existing = getattr(self._connection, "conn", None)
        if existing is not None:
            return existing
        new_connection = Connection(
            self, self._db_file, self._is_uri, *args, **kwargs
        )
        self._connection.conn = new_connection
        with self._lock:
            self._connections.add(weakref.ref(new_connection))
        return new_connection

    @override
    def close(self) -> None:
        """Close every still-alive connection and reset per-thread state."""
        with self._lock:
            for ref in self._connections:
                # Dereference once; a second ref() call could return None if
                # the connection was collected in between.
                conn = ref()
                if conn is not None:
                    conn.close_actual()
            self._connections.clear()
            # Fresh thread-local so the next connect() on any thread opens a
            # new connection instead of reusing a closed one.
            self._connection = threading.local()

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """No-op: each thread keeps its connection until close()."""
        pass
|
|