"""PostgreSQL or SQLite database tables for RAGLite."""

import datetime
import json
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Any

import numpy as np
from litellm import get_model_info # type: ignore[attr-defined]
from markdown_it import MarkdownIt
from pydantic import ConfigDict
from sqlalchemy.engine import Engine, make_url
from sqlmodel import (
JSON,
Column,
Field,
Relationship,
Session,
SQLModel,
create_engine,
text,
)

from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject


def hash_bytes(data: bytes, max_len: int = 16) -> str:
"""Hash bytes to a hexadecimal string."""
return sha256(data, usedforsecurity=False).hexdigest()[:max_len]


class Document(SQLModel, table=True):
"""A document."""
# Enable JSON columns.
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
id: str = Field(..., primary_key=True)
filename: str
url: str | None = Field(default=None)
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
# Add relationships so we can access document.chunks and document.evals.
chunks: list["Chunk"] = Relationship(back_populates="document")
evals: list["Eval"] = Relationship(back_populates="document")

    @staticmethod
def from_path(doc_path: Path, **kwargs: Any) -> "Document":
"""Create a document from a file path."""
return Document(
id=hash_bytes(doc_path.read_bytes()),
filename=doc_path.name,
metadata_={
"size": doc_path.stat().st_size,
"created": doc_path.stat().st_ctime,
"modified": doc_path.stat().st_mtime,
**kwargs,
},
)


class Chunk(SQLModel, table=True):
"""A document chunk."""
# Enable JSON columns.
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
index: int = Field(..., index=True)
headings: str
body: str
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
# Add relationships so we can access chunk.document and chunk.embeddings.
document: Document = Relationship(back_populates="chunks")
embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")

    @staticmethod
def from_body(
document_id: str,
index: int,
body: str,
headings: str = "",
**kwargs: Any,
) -> "Chunk":
"""Create a chunk from Markdown."""
return Chunk(
id=hash_bytes(body.encode()),
document_id=document_id,
index=index,
headings=headings,
body=body,
metadata_=kwargs,
)

    def extract_headings(self) -> str:
"""Extract Markdown headings from the chunk, starting from the current Markdown headings."""
md = MarkdownIt()
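        # heading_lines[i] holds the most recent level-i heading seen so far ("#" * i prefix);
        # parsing self.headings first seeds that state, after which self.body updates it.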
heading_lines = [""] * 10
level = None
for doc in (self.headings, self.body):
for token in md.parse(doc):
if token.type == "heading_open":
level = int(token.tag[1])
elif token.type == "heading_close":
level = None
elif level is not None:
heading_content = token.content.strip().replace("\n", " ")
heading_lines[level] = ("#" * level) + " " + heading_content
heading_lines[level + 1 :] = [""] * len(heading_lines[level + 1 :])
headings = "\n".join([heading for heading in heading_lines if heading])
return headings

    @property
def embedding_matrix(self) -> FloatMatrix:
"""Return this chunk's multi-vector embedding matrix."""
# Uses the relationship chunk.embeddings to access the chunk_embedding table.
return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings])

    def __hash__(self) -> int:
return hash(self.id)

    def __repr__(self) -> str:
return json.dumps(
{
"id": self.id,
"document_id": self.document_id,
"index": self.index,
"headings": self.headings,
"body": self.body[:100],
"metadata": self.metadata_,
},
indent=4,
)

    def __str__(self) -> str:
"""Context representation of this chunk."""
return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()


class ChunkEmbedding(SQLModel, table=True):
"""A (sub-)chunk embedding."""
__tablename__ = "chunk_embedding"
# Enable Embedding columns.
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
id: int = Field(..., primary_key=True)
chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
# Add relationship so we can access embedding.chunk.
chunk: Chunk = Relationship(back_populates="embeddings")

    @classmethod
def set_embedding_dim(cls, dim: int) -> None:
"""Modify the embedding column's dimension after class definition."""
cls.__table__.c["embedding"].type.dim = dim # type: ignore[attr-defined]


class IndexMetadata(SQLModel, table=True):
"""Vector and keyword search index metadata."""
__tablename__ = "index_metadata"
# Enable PickledObject columns.
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
id: str = Field(..., primary_key=True)
version: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
)
metadata_: dict[str, Any] = Field(
default_factory=dict, sa_column=Column("metadata", PickledObject)
)

    @staticmethod
@lru_cache(maxsize=4)
def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
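        # lru_cache requires hashable arguments, so `config` must be hashable; equal configs
        # share the cached metadata and skip the database round trip.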
engine = create_database_engine(config)
with Session(engine) as session:
index_metadata_record = session.get(IndexMetadata, id_)
if index_metadata_record is None:
return None
return index_metadata_record.metadata_

    @staticmethod
def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
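        """Return the named index metadata, or an empty dict if none has been persisted yet."""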
metadata = IndexMetadata._get(id_, config=config) or {}
return metadata


class Eval(SQLModel, table=True):
"""A RAG evaluation example."""
__tablename__ = "eval"
# Enable JSON columns.
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
# Table columns.
id: str = Field(..., primary_key=True)
document_id: str = Field(..., foreign_key="document.id", index=True)
chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
question: str
contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
ground_truth: str
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
# Add relationship so we can access eval.document.
document: Document = Relationship(back_populates="evals")

    @staticmethod
def from_chunks(
question: str,
contexts: list[Chunk],
ground_truth: str,
**kwargs: Any,
) -> "Eval":
"""Create a chunk from Markdown."""
document_id = contexts[0].document_id
chunk_ids = [context.id for context in contexts]
return Eval(
id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
document_id=document_id,
chunk_ids=chunk_ids,
question=question,
contexts=[str(context) for context in contexts],
ground_truth=ground_truth,
metadata_=kwargs,
)
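

# The engine is cached so that repeated calls with an equal (hashable) config reuse a single
# engine and connection pool instead of re-running the initialization below.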
@lru_cache(maxsize=1)
def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
"""Create a database engine and initialize it."""
# Parse the database URL and validate that the database backend is supported.
config = config or RAGLiteConfig()
db_url = make_url(config.db_url)
db_backend = db_url.get_backend_name()
# Update database configuration.
connect_args = {}
if db_backend == "postgresql":
# Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL.
if "+" not in db_url.drivername:
db_url = db_url.set(drivername="postgresql+pg8000")
# Support setting the sslmode for pg8000.
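        # pg8000 does not understand libpq's `sslmode` URL parameter, so it is stripped from the
        # URL and translated into an `ssl_context` connect argument (anything but "disable"
        # enables SSL with a default context).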
if "pg8000" in db_url.drivername and "sslmode" in db_url.query:
query = dict(db_url.query)
if query.pop("sslmode") != "disable":
connect_args["ssl_context"] = True
db_url = db_url.set(query=query)
elif db_backend == "sqlite":
# Optimize SQLite performance.
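        # WAL journaling lets readers proceed concurrently with a writer, and synchronous=NORMAL
        # reduces fsync overhead at a small durability cost that is acceptable under WAL.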
pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
db_url = db_url.update_query_dict(pragmas, append=True)
else:
error_message = "RAGLite only supports PostgreSQL and SQLite."
raise ValueError(error_message)
# Create the engine.
engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args)
# Install database extensions.
if db_backend == "postgresql":
with Session(engine) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
session.commit()
# If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
# to date by loading that LLM.
if config.embedder.startswith("llama-cpp-python"):
_ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
    embedding_dim = model_info.get("output_vector_size") or -1
    assert embedding_dim > 0, "Unable to determine the embedder's output dimension."
# Create all SQLModel tables.
ChunkEmbedding.set_embedding_dim(embedding_dim)
SQLModel.metadata.create_all(engine)
# Create backend-specific indexes.
if db_backend == "postgresql":
# Create a keyword search index with `tsvector` and a vector search index with `pgvector`.
with Session(engine) as session:
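            # Map RAGLite's metric names to pgvector operator class suffixes; the embeddings are
            # indexed as half-precision halfvec values to shrink the HNSW index.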
metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
session.execute(
text("""
CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
""")
)
session.execute(
text(f"""
CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
USING hnsw (
(embedding::halfvec({embedding_dim}))
halfvec_{metrics[config.vector_search_index_metric]}_ops
);
""")
)
session.commit()
elif db_backend == "sqlite":
# Create a virtual table for keyword search on the chunk table.
# We use the chunk table as an external content table [1] to avoid duplicating the data.
# [1] https://www.sqlite.org/fts5.html#external_content_tables
with Session(engine) as session:
session.execute(
text("""
CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
""")
)
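            # FTS5 external content tables are not updated automatically, so AFTER INSERT, DELETE,
            # and UPDATE triggers mirror every change, using FTS5's special 'delete' command to
            # drop stale rows from the index.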
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
""")
)
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
END;
""")
)
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
END;
""")
)
session.commit()
return engine
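

# A minimal usage sketch (the config values below are illustrative, not defaults):
#
#   config = RAGLiteConfig(db_url="sqlite:///raglite.db", embedder="text-embedding-3-small")
#   engine = create_database_engine(config)
#   with Session(engine) as session:
#       session.add(Document.from_path(Path("report.pdf")))
#       session.commit()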