|
from typing import Any, Optional, Sequence, Tuple, Type, Union |
|
from types import TracebackType |
|
from typing_extensions import Protocol, Self, Literal |
|
from abc import ABC, abstractmethod |
|
from threading import local |
|
from overrides import override, EnforceOverrides |
|
import pypika |
|
import pypika.queries |
|
from chromadb.config import System, Component |
|
from uuid import UUID |
|
from itertools import islice, count |
|
|
|
from chromadb.types import SeqId |
|
|
|
|
|
class UniqueConstraintError(Exception): |
|
"""Raised when an insert operation would violate a unique constraint""" |
|
|
|
pass |
|
|
|
|
|
class Cursor(Protocol): |
|
"""Reifies methods we use from a DBAPI2 Cursor since DBAPI2 is not typed.""" |
|
|
|
def execute(self, sql: str, params: Optional[Tuple[Any, ...]] = None) -> Self: |
|
... |
|
|
|
def executescript(self, script: str) -> Self: |
|
... |
|
|
|
def executemany( |
|
self, sql: str, params: Optional[Sequence[Tuple[Any, ...]]] = None |
|
) -> Self: |
|
... |
|
|
|
def fetchone(self) -> Tuple[Any, ...]: |
|
... |
|
|
|
def fetchall(self) -> Sequence[Tuple[Any, ...]]: |
|
... |
|
|
|
|
|
class TxWrapper(ABC, EnforceOverrides): |
|
"""Wrapper class for DBAPI 2.0 Connection objects, with which clients can implement transactions. |
|
Makes two guarantees that basic DBAPI 2.0 connections do not: |
|
|
|
- __enter__ returns a Cursor object consistently (instead of a Connection like some do) |
|
- Always re-raises an exception if one was thrown from the body |
|
""" |
|
|
|
@abstractmethod |
|
def __enter__(self) -> Cursor: |
|
pass |
|
|
|
@abstractmethod |
|
def __exit__( |
|
self, |
|
exc_type: Optional[Type[BaseException]], |
|
exc_value: Optional[BaseException], |
|
traceback: Optional[TracebackType], |
|
) -> Literal[False]: |
|
pass |
|
|
|
|
|
class SqlDB(Component): |
|
"""DBAPI 2.0 interface wrapper to ensure consistent behavior between implementations""" |
|
|
|
def __init__(self, system: System): |
|
super().__init__(system) |
|
|
|
@abstractmethod |
|
def tx(self) -> TxWrapper: |
|
"""Return a transaction wrapper""" |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def querybuilder() -> Type[pypika.Query]: |
|
"""Return a PyPika Query builder of an appropriate subtype for this database |
|
implementation (see |
|
https://pypika.readthedocs.io/en/latest/3_advanced.html#handling-different-database-platforms) |
|
""" |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def parameter_format() -> str: |
|
"""Return the appropriate parameter format for this database implementation. |
|
Will be called with str.format(i) where i is the numeric index of the parameter. |
|
""" |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]: |
|
"""Convert a UUID to a value that can be passed to the DB driver""" |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def uuid_from_db(value: Optional[Any]) -> Optional[UUID]: |
|
"""Convert a value from the DB driver to a UUID""" |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def unique_constraint_error() -> Type[BaseException]: |
|
"""Return the exception type that the DB raises when a unique constraint is |
|
violated""" |
|
pass |
|
|
|
def param(self, idx: int) -> pypika.Parameter: |
|
"""Return a PyPika Parameter object for the given index""" |
|
return pypika.Parameter(self.parameter_format().format(idx)) |
|
|
|
@staticmethod |
|
def decode_seq_id(seq_id_bytes: Union[bytes, int]) -> SeqId: |
|
"""Decode a byte array into a SeqID""" |
|
if isinstance(seq_id_bytes, int): |
|
return seq_id_bytes |
|
|
|
if len(seq_id_bytes) == 8: |
|
return int.from_bytes(seq_id_bytes, "big") |
|
elif len(seq_id_bytes) == 24: |
|
return int.from_bytes(seq_id_bytes, "big") |
|
else: |
|
raise ValueError(f"Unknown SeqID type with length {len(seq_id_bytes)}") |
|
|
|
@staticmethod |
|
def encode_seq_id(seq_id: SeqId) -> bytes: |
|
"""Encode a SeqID into a byte array""" |
|
if seq_id.bit_length() <= 64: |
|
return int.to_bytes(seq_id, 8, "big") |
|
elif seq_id.bit_length() <= 192: |
|
return int.to_bytes(seq_id, 24, "big") |
|
else: |
|
raise ValueError(f"Unsupported SeqID: {seq_id}") |
|
|
|
|
|
_context = local() |
|
|
|
|
|
class ParameterValue(pypika.Parameter): |
|
""" |
|
Wrapper class for PyPika paramters that allows the values for Parameters |
|
to be expressed inline while building a query. See get_sql() for |
|
detailed usage information. |
|
""" |
|
|
|
def __init__(self, value: Any): |
|
self.value = value |
|
|
|
@override |
|
def get_sql(self, **kwargs: Any) -> str: |
|
if isinstance(self.value, (list, tuple)): |
|
_context.values.extend(self.value) |
|
indexes = islice(_context.generator, len(self.value)) |
|
placeholders = ", ".join(_context.formatstr.format(i) for i in indexes) |
|
val = f"({placeholders})" |
|
else: |
|
_context.values.append(self.value) |
|
val = _context.formatstr.format(next(_context.generator)) |
|
|
|
return str(val) |
|
|
|
|
|
def get_sql( |
|
query: pypika.queries.QueryBuilder, formatstr: str = "?" |
|
) -> Tuple[str, Tuple[Any, ...]]: |
|
""" |
|
Wrapper for pypika's get_sql method that allows the values for Parameters |
|
to be expressed inline while building a query, and that returns a tuple of the |
|
SQL string and parameters. This makes it easier to construct complex queries |
|
programmatically and automatically matches up the generated SQL with the required |
|
parameter vector. |
|
|
|
Doing so requires using the ParameterValue class defined in this module instead |
|
of the base pypika.Parameter class. |
|
|
|
Usage Example: |
|
|
|
q = ( |
|
pypika.Query().from_("table") |
|
.select("col1") |
|
.where("col2"==ParameterValue("foo")) |
|
.where("col3"==ParameterValue("bar")) |
|
) |
|
|
|
sql, params = get_sql(q) |
|
|
|
cursor.execute(sql, params) |
|
|
|
Note how it is not necessary to construct the parameter vector manually... it |
|
will always be generated with the parameter values in the same order as emitted |
|
SQL string. |
|
|
|
The format string should match the parameter format for the database being used. |
|
It will be called with str.format(i) where i is the numeric index of the parameter. |
|
For example, Postgres requires parameters like `:1`, `:2`, etc. so the format string |
|
should be `":{}"`. |
|
|
|
See https://pypika.readthedocs.io/en/latest/2_tutorial.html#parametrized-queries for more |
|
information on parameterized queries in PyPika. |
|
""" |
|
|
|
_context.values = [] |
|
_context.generator = count(1) |
|
_context.formatstr = formatstr |
|
sql = query.get_sql() |
|
params = tuple(_context.values) |
|
return sql, params |
|
|