import logging
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.config import System, Settings
import chromadb.db.base as base
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
from chromadb.db.mixins.sysdb import SqlSysDB
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
import sqlite3
from overrides import override
import pypika
from typing import Sequence, cast, Optional, Type, Any
from typing_extensions import Literal
from types import TracebackType
import os
from uuid import UUID
from threading import local
from importlib_resources import files
from importlib_resources.abc import Traversable
logger = logging.getLogger(__name__)
class TxWrapper(base.TxWrapper):
    """Context manager wrapping a pooled SQLite connection in a transaction.

    Nesting is supported through a thread-local stack of wrappers: only the
    outermost wrapper opens the transaction on entry, and only the outermost
    commits or rolls back on exit. Every wrapper checks a connection out of
    the pool on construction and returns it on exit.
    """

    _conn: Connection
    _pool: Pool

    def __init__(self, conn_pool: Pool, stack: local):
        self._tx_stack = stack
        self._conn = conn_pool.connect()
        self._pool = conn_pool

    @override
    def __enter__(self) -> base.Cursor:
        # Only the outermost wrapper begins a transaction; nested wrappers
        # join the transaction already in progress.
        if not self._tx_stack.stack:
            self._conn.execute("PRAGMA case_sensitive_like = ON")
            self._conn.execute("BEGIN;")
        self._tx_stack.stack.append(self)
        return self._conn.cursor()  # type: ignore

    @override
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        self._tx_stack.stack.pop()
        # The outermost wrapper finalizes: commit on success, roll back if an
        # exception is propagating out of the with-block.
        if not self._tx_stack.stack:
            if exc_type is None:
                self._conn.commit()
            else:
                self._conn.rollback()
        self._conn.cursor().close()
        self._pool.return_to_pool(self._conn)
        # Returning False never suppresses the caller's exception.
        return False
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """SQLite-backed implementation of the sysdb, embeddings queue, and
    migratable-DB interfaces.

    Uses a shared-cache in-memory database when the system is not persistent,
    otherwise a ``chroma.sqlite3`` file under the configured persist directory.
    """

    _conn_pool: Pool
    _settings: Settings
    _migration_imports: Sequence[Traversable]
    _db_file: str
    _tx_stack: local
    _is_persistent: bool

    def __init__(self, system: System):
        self._settings = system.settings
        # Packages containing the SQL migration files for each subsystem.
        self._migration_imports = [
            files("chromadb.migrations.embeddings_queue"),
            files("chromadb.migrations.sysdb"),
            files("chromadb.migrations.metadb"),
        ]
        self._is_persistent = self._settings.require("is_persistent")
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        if not self._is_persistent:
            # In order to allow sqlite to be shared between multiple threads, we need to use a
            # URI connection string with shared cache.
            # See https://www.sqlite.org/sharedcache.html
            # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
            self._db_file = "file::memory:?cache=shared"
            self._conn_pool = LockPool(self._db_file, is_uri=True)
        else:
            self._db_file = (
                self._settings.require("persist_directory") + "/chroma.sqlite3"
            )
            if not os.path.exists(self._db_file):
                os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
            self._conn_pool = PerThreadPool(self._db_file)
        self._tx_stack = local()
        super().__init__(system)

    @trace_method("SqliteDB.start", OpenTelemetryGranularity.ALL)
    @override
    def start(self) -> None:
        """Start the component: enable SQLite pragmas and run migrations."""
        super().start()
        with self.tx() as cur:
            cur.execute("PRAGMA foreign_keys = ON")
            cur.execute("PRAGMA case_sensitive_like = ON")
        self.initialize_migrations()

        if (
            # (don't attempt to access .config if migrations haven't been run)
            self._settings.require("migrations") == "apply"
            and self.config.get_parameter("automatically_purge").value is False
        ):
            # NOTE: logger.warn() is a deprecated alias of logger.warning().
            logger.warning(
                "⚠️ It looks like you upgraded from a version below 0.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information."
            )

    @trace_method("SqliteDB.stop", OpenTelemetryGranularity.ALL)
    @override
    def stop(self) -> None:
        """Stop the component and close all pooled connections."""
        super().stop()
        self._conn_pool.close()

    @staticmethod
    @override
    def querybuilder() -> Type[pypika.Query]:
        return pypika.Query  # type: ignore

    @staticmethod
    @override
    def parameter_format() -> str:
        # SQLite uses "?" placeholders for bound parameters.
        return "?"

    @staticmethod
    @override
    def migration_scope() -> str:
        return "sqlite"

    @override
    def migration_dirs(self) -> Sequence[Traversable]:
        return self._migration_imports

    @override
    def tx(self) -> TxWrapper:
        """Return a transaction context manager; nested calls share one tx."""
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn_pool, stack=self._tx_stack)

    @trace_method("SqliteDB.reset_state", OpenTelemetryGranularity.ALL)
    @override
    def reset_state(self) -> None:
        """Drop all tables and restart the component. Requires `allow_reset`.

        Raises:
            ValueError: if `allow_reset` is not enabled in settings.
        """
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        with self.tx() as cur:
            # Drop all tables
            cur.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type='table'
                """
            )
            for row in cur.fetchall():
                # Table names come from sqlite_master, not user input.
                cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
        self._conn_pool.close()
        self.start()
        super().reset_state()

    @trace_method("SqliteDB.setup_migrations", OpenTelemetryGranularity.ALL)
    @override
    def setup_migrations(self) -> None:
        """Create the migrations bookkeeping table if it does not exist."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    @trace_method("SqliteDB.migrations_initialized", OpenTelemetryGranularity.ALL)
    @override
    def migrations_initialized(self) -> bool:
        """Return True if the migrations bookkeeping table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                   WHERE type='table' AND name='migrations'"""
            )
            return cur.fetchone()[0] != 0

    @trace_method("SqliteDB.db_migrations", OpenTelemetryGranularity.ALL)
    @override
    def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
        """Return all migrations already applied for the given directory,
        ordered by ascending version."""
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir.name,),
            )

            migrations = []
            for row in cur.fetchall():
                found_dir = cast(str, row[0])
                found_version = cast(int, row[1])
                found_filename = cast(str, row[2])
                found_sql = cast(str, row[3])
                found_hash = cast(str, row[4])
                migrations.append(
                    Migration(
                        dir=found_dir,
                        version=found_version,
                        filename=found_filename,
                        sql=found_sql,
                        hash=found_hash,
                        scope=self.migration_scope(),
                    )
                )
            return migrations

    @override
    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Execute a migration's SQL and record it in the migrations table."""
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    @override
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        return UUID(value) if value is not None else None

    @staticmethod
    @override
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        return str(uuid) if uuid is not None else None

    @staticmethod
    @override
    def unique_constraint_error() -> Type[BaseException]:
        return sqlite3.IntegrityError

    def vacuum(self, timeout: int = 5) -> None:
        """Runs VACUUM on the database. `timeout` is the maximum time to wait for an exclusive lock in seconds."""
        conn = self._conn_pool.connect()
        try:
            conn.execute(f"PRAGMA busy_timeout = {int(timeout) * 1000}")
            conn.execute("VACUUM")
            conn.execute(
                """
                INSERT INTO maintenance_log (operation, timestamp)
                VALUES ('vacuum', CURRENT_TIMESTAMP)
                """
            )
            # Python's sqlite3 opens an implicit transaction for DML; commit
            # it so the maintenance_log row is not silently rolled back.
            conn.commit()
        finally:
            # Always return the connection so it is not leaked, even when
            # VACUUM fails (e.g. busy_timeout expires).
            self._conn_pool.return_to_pool(conn)
|