"""PostgreSQL or SQLite database tables for RAGLite.""" | |
import datetime | |
import json | |
from functools import lru_cache | |
from hashlib import sha256 | |
from pathlib import Path | |
from typing import Any | |
import numpy as np | |
from litellm import get_model_info # type: ignore[attr-defined] | |
from markdown_it import MarkdownIt | |
from pydantic import ConfigDict | |
from sqlalchemy.engine import Engine, make_url | |
from sqlmodel import ( | |
JSON, | |
Column, | |
Field, | |
Relationship, | |
Session, | |
SQLModel, | |
create_engine, | |
text, | |
) | |
from raglite._config import RAGLiteConfig | |
from raglite._litellm import LlamaCppPythonLLM | |
from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject | |


def hash_bytes(data: bytes, max_len: int = 16) -> str:
    """Hash bytes to a hexadecimal string."""
    return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
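
# A quick sanity check of the hash above (illustrative doctest, not part of the API):
#     >>> hash_bytes(b"hello")
#     '2cf24dba5fb0a30e'
# Identical inputs always map to the same 16-character prefix, which keeps document and chunk ids
# stable across re-inserts.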


class Document(SQLModel, table=True):
    """A document."""

    # Enable JSON columns.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]

    # Table columns.
    id: str = Field(..., primary_key=True)
    filename: str
    url: str | None = Field(default=None)
    metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))

    # Add relationships so we can access document.chunks and document.evals.
    chunks: list["Chunk"] = Relationship(back_populates="document")
    evals: list["Eval"] = Relationship(back_populates="document")

    @staticmethod
    def from_path(doc_path: Path, **kwargs: Any) -> "Document":
        """Create a document from a file path."""
        return Document(
            id=hash_bytes(doc_path.read_bytes()),
            filename=doc_path.name,
            metadata_={
                "size": doc_path.stat().st_size,
                "created": doc_path.stat().st_ctime,
                "modified": doc_path.stat().st_mtime,
                **kwargs,
            },
        )
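
# Illustrative usage of `Document.from_path` (the file name and extra metadata are made up):
#     doc = Document.from_path(Path("manual.pdf"), source="local-disk")
#     doc.id         # 16-character hash of the file's bytes.
#     doc.metadata_  # {"size": ..., "created": ..., "modified": ..., "source": "local-disk"}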


class Chunk(SQLModel, table=True):
    """A document chunk."""

    # Enable JSON columns.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]

    # Table columns.
    id: str = Field(..., primary_key=True)
    document_id: str = Field(..., foreign_key="document.id", index=True)
    index: int = Field(..., index=True)
    headings: str
    body: str
    metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))

    # Add relationships so we can access chunk.document and chunk.embeddings.
    document: Document = Relationship(back_populates="chunks")
    embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")

    @staticmethod
    def from_body(
        document_id: str,
        index: int,
        body: str,
        headings: str = "",
        **kwargs: Any,
    ) -> "Chunk":
        """Create a chunk from Markdown."""
        return Chunk(
            id=hash_bytes(body.encode()),
            document_id=document_id,
            index=index,
            headings=headings,
            body=body,
            metadata_=kwargs,
        )
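
    # Illustrative usage (ids and values are made up):
    #     chunk = Chunk.from_body(
    #         document_id=doc.id, index=0, body="RAGLite supports PostgreSQL and SQLite.",
    #         headings="# Introduction",
    #     )
    # Note that the chunk id hashes the body alone, so identical bodies share an id.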

    def extract_headings(self) -> str:
        """Extract Markdown headings from the chunk, starting from the current Markdown headings."""
        md = MarkdownIt()
        heading_lines = [""] * 10
        level = None
        for doc in (self.headings, self.body):
            for token in md.parse(doc):
                if token.type == "heading_open":
                    level = int(token.tag[1])
                elif token.type == "heading_close":
                    level = None
                elif level is not None:
                    heading_content = token.content.strip().replace("\n", " ")
                    heading_lines[level] = ("#" * level) + " " + heading_content
                    heading_lines[level + 1 :] = [""] * len(heading_lines[level + 1 :])
        headings = "\n".join([heading for heading in heading_lines if heading])
        return headings
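
    # For example, if the current headings are "# A\n## B" and the body opens with "## C", the
    # extracted headings become "# A\n## C": the new level-2 heading replaces "## B" and clears
    # any deeper levels.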

    @property
    def embedding_matrix(self) -> FloatMatrix:
        """Return this chunk's multi-vector embedding matrix."""
        # Uses the relationship chunk.embeddings to access the chunk_embedding table.
        return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings])
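
    # The stacked matrix has shape (number of sub-chunk embeddings, embedding dimension): with
    # multi-vector embedding, a single chunk contributes one row per sub-chunk.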

    def __hash__(self) -> int:
        return hash(self.id)

    def __repr__(self) -> str:
        return json.dumps(
            {
                "id": self.id,
                "document_id": self.document_id,
                "index": self.index,
                "headings": self.headings,
                "body": self.body[:100],
                "metadata": self.metadata_,
            },
            indent=4,
        )

    def __str__(self) -> str:
        """Context representation of this chunk."""
        return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()


class ChunkEmbedding(SQLModel, table=True):
    """A (sub-)chunk embedding."""

    __tablename__ = "chunk_embedding"

    # Enable Embedding columns.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]

    # Table columns.
    id: int = Field(..., primary_key=True)
    chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
    embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))

    # Add relationship so we can access embedding.chunk.
    chunk: Chunk = Relationship(back_populates="embeddings")

    @classmethod
    def set_embedding_dim(cls, dim: int) -> None:
        """Modify the embedding column's dimension after class definition."""
        cls.__table__.c["embedding"].type.dim = dim  # type: ignore[attr-defined]
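
# The embedding dimension is only known at runtime (it depends on the configured embedder), so
# `create_database_engine` below patches it in before the tables are created, e.g.:
#     ChunkEmbedding.set_embedding_dim(1024)  # 1024 is an illustrative dimension.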


class IndexMetadata(SQLModel, table=True):
    """Vector and keyword search index metadata."""

    __tablename__ = "index_metadata"

    # Enable PickledObject columns.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]

    # Table columns.
    id: str = Field(..., primary_key=True)
    version: datetime.datetime = Field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )
    metadata_: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column("metadata", PickledObject)
    )

    @staticmethod
    def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
        engine = create_database_engine(config)
        with Session(engine) as session:
            index_metadata_record = session.get(IndexMetadata, id_)
            if index_metadata_record is None:
                return None
            return index_metadata_record.metadata_

    @staticmethod
    def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
        metadata = IndexMetadata._get(id_, config=config) or {}
        return metadata
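
# Illustrative usage (`my_config` is hypothetical):
#     index_metadata = IndexMetadata.get("default", config=my_config)
# Returns an empty dict when no index metadata has been stored under that id yet.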


class Eval(SQLModel, table=True):
    """A RAG evaluation example."""

    __tablename__ = "eval"

    # Enable JSON columns.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]

    # Table columns.
    id: str = Field(..., primary_key=True)
    document_id: str = Field(..., foreign_key="document.id", index=True)
    chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
    question: str
    contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
    ground_truth: str
    metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))

    # Add relationship so we can access eval.document.
    document: Document = Relationship(back_populates="evals")

    @staticmethod
    def from_chunks(
        question: str,
        contexts: list[Chunk],
        ground_truth: str,
        **kwargs: Any,
    ) -> "Eval":
        """Create an eval from a question, its context chunks, and a ground-truth answer."""
        document_id = contexts[0].document_id
        chunk_ids = [context.id for context in contexts]
        return Eval(
            id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
            document_id=document_id,
            chunk_ids=chunk_ids,
            question=question,
            contexts=[str(context) for context in contexts],
            ground_truth=ground_truth,
            metadata_=kwargs,
        )
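
# Illustrative usage (question, chunks, and answer are made up):
#     eval_ = Eval.from_chunks(
#         question="Which databases does RAGLite support?",
#         contexts=[chunk],
#         ground_truth="PostgreSQL and SQLite.",
#     )
# All contexts are assumed to come from the same document: the eval's document_id is taken from
# the first chunk.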


@lru_cache(maxsize=1)  # Reuse the engine (and its connection pool) across calls.
def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
    """Create a database engine and initialize it."""
    # Parse the database URL and validate that the database backend is supported.
    config = config or RAGLiteConfig()
    db_url = make_url(config.db_url)
    db_backend = db_url.get_backend_name()
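    # For example, "postgresql://user:pass@host:5432/db" parses to the "postgresql" backend, and
    # "sqlite:///raglite.db" parses to the "sqlite" backend (illustrative URLs).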
    # Update database configuration.
    connect_args = {}
    if db_backend == "postgresql":
        # Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL.
        if "+" not in db_url.drivername:
            db_url = db_url.set(drivername="postgresql+pg8000")
        # Support setting the sslmode for pg8000.
        if "pg8000" in db_url.drivername and "sslmode" in db_url.query:
            query = dict(db_url.query)
            if query.pop("sslmode") != "disable":
                connect_args["ssl_context"] = True
            db_url = db_url.set(query=query)
    elif db_backend == "sqlite":
        # Optimize SQLite performance.
        pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
        db_url = db_url.update_query_dict(pragmas, append=True)
    else:
        error_message = "RAGLite only supports PostgreSQL and SQLite."
        raise ValueError(error_message)
    # Create the engine.
    engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args)
    # Install database extensions.
    if db_backend == "postgresql":
        with Session(engine) as session:
            session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
            session.commit()
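    # Note that `CREATE EXTENSION vector` only succeeds if pgvector is installed on the server.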
    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is
    # up to date by loading that LLM.
    if config.embedder.startswith("llama-cpp-python"):
        _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
    model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
    embedding_dim = model_info.get("output_vector_size") or -1
    assert embedding_dim > 0
    # Create all SQLModel tables.
    ChunkEmbedding.set_embedding_dim(embedding_dim)
    SQLModel.metadata.create_all(engine)
    # Create backend-specific indexes.
    if db_backend == "postgresql":
        # Create a keyword search index with `tsvector` and a vector search index with `pgvector`.
        with Session(engine) as session:
            metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
            session.execute(
                text("""
                CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
                """)
            )
            session.execute(
                text(f"""
                CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
                USING hnsw (
                    (embedding::halfvec({embedding_dim}))
                    halfvec_{metrics[config.vector_search_index_metric]}_ops
                );
                """)
            )
            session.commit()
elif db_backend == "sqlite": | |
# Create a virtual table for keyword search on the chunk table. | |
# We use the chunk table as an external content table [1] to avoid duplicating the data. | |
# [1] https://www.sqlite.org/fts5.html#external_content_tables | |
with Session(engine) as session: | |
session.execute( | |
text(""" | |
CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid'); | |
""") | |
) | |
session.execute( | |
text(""" | |
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN | |
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body); | |
END; | |
""") | |
) | |
session.execute( | |
text(""" | |
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN | |
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); | |
END; | |
""") | |
) | |
session.execute( | |
text(""" | |
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN | |
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body); | |
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body); | |
END; | |
""") | |
) | |
session.commit() | |
return engine | |
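
# Illustrative end-to-end usage, assuming a hypothetical local SQLite database:
#     config = RAGLiteConfig(db_url="sqlite:///raglite.db")
#     engine = create_database_engine(config)
#     with Session(engine) as session:
#         session.add(Document.from_path(Path("manual.pdf")))  # Hypothetical file.
#         session.commit()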