# Source provenance (was raw VCS paste residue, which is not valid Python):
# author: AmmarFahmy — "adding all files" (commit 105b369)
from typing import Optional, List, Union
from hashlib import md5
try:
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine, Engine
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.schema import MetaData, Table, Column
from sqlalchemy.sql.expression import text, func, select
from sqlalchemy.types import DateTime, String
except ImportError:
raise ImportError("`sqlalchemy` not installed")
try:
from pgvector.sqlalchemy import Vector
except ImportError:
raise ImportError("`pgvector` not installed")
from phi.document import Document
from phi.embedder import Embedder
from phi.vectordb.base import VectorDb
from phi.vectordb.distance import Distance
from phi.vectordb.pgvector.index import Ivfflat, HNSW
from phi.utils.log import logger
class PgVector(VectorDb):
    """PostgreSQL + pgvector implementation of the VectorDb interface.

    Documents are stored in the ``schema.collection`` table with their name,
    metadata, content, embedding and an md5 hash of the content that is used
    to detect duplicates. Similarity search orders rows with the pgvector
    distance operator matching the configured ``Distance`` metric.
    """

    def __init__(
        self,
        collection: str,
        schema: Optional[str] = "ai",
        db_url: Optional[str] = None,
        db_engine: Optional[Engine] = None,
        embedder: Optional[Embedder] = None,
        distance: Distance = Distance.cosine,
        index: Optional[Union[Ivfflat, HNSW]] = HNSW(),
    ):
        """
        Args:
            collection: Name of the table that stores the documents.
            schema: Postgres schema for the table (None uses the default schema).
            db_url: Database URL; used to create an engine when db_engine is not given.
            db_engine: Pre-built SQLAlchemy engine; takes precedence over db_url.
            embedder: Embedder for documents and queries (defaults to OpenAIEmbedder).
            distance: Distance metric used to order search results.
            index: Vector index configuration (Ivfflat or HNSW), or None for no index.

        Raises:
            ValueError: If neither db_url nor db_engine is provided.
        """
        _engine: Optional[Engine] = db_engine
        if _engine is None and db_url is not None:
            _engine = create_engine(db_url)
        if _engine is None:
            raise ValueError("Must provide either db_url or db_engine")

        # Collection attributes
        self.collection: str = collection
        self.schema: Optional[str] = schema

        # Database attributes
        self.db_url: Optional[str] = db_url
        self.db_engine: Engine = _engine
        self.metadata: MetaData = MetaData(schema=self.schema)

        # Embedder for embedding the document contents
        _embedder = embedder
        if _embedder is None:
            # Imported lazily so the OpenAI dependency is only required
            # when no embedder is supplied.
            from phi.embedder.openai import OpenAIEmbedder

            _embedder = OpenAIEmbedder()
        self.embedder: Embedder = _embedder
        self.dimensions: int = self.embedder.dimensions

        # Distance metric
        self.distance: Distance = distance

        # Index for the collection.
        # NOTE(review): the HNSW() default in the signature is a shared mutable
        # default (one instance for every PgVector that doesn't pass `index`);
        # optimize() therefore avoids mutating it — see the local index name there.
        self.index: Optional[Union[Ivfflat, HNSW]] = index

        # Database session factory
        self.Session: sessionmaker[Session] = sessionmaker(bind=self.db_engine)

        # Database table for the collection
        self.table: Table = self.get_table()

    def get_table(self) -> Table:
        """Build (or extend) the SQLAlchemy Table definition for this collection."""
        # NOTE(review): upsert() relies on a unique constraint over
        # (name, content_hash); this table definition does not declare one, so
        # on_conflict_do_update will fail unless the constraint exists in the
        # database already — confirm against the migration that creates it.
        return Table(
            self.collection,
            self.metadata,
            Column("name", String),
            Column("meta_data", postgresql.JSONB, server_default=text("'{}'::jsonb")),
            Column("content", postgresql.TEXT),
            Column("embedding", Vector(self.dimensions)),
            Column("usage", postgresql.JSONB),
            Column("created_at", DateTime(timezone=True), server_default=text("now()")),
            Column("updated_at", DateTime(timezone=True), onupdate=text("now()")),
            Column("content_hash", String),
            extend_existing=True,
        )

    def table_exists(self) -> bool:
        """Return True if the collection table exists in the database."""
        logger.debug(f"Checking if table exists: {self.table.name}")
        try:
            return inspect(self.db_engine).has_table(self.table.name, schema=self.schema)
        except Exception as e:
            # Best-effort: treat any introspection failure as "does not exist".
            logger.error(e)
            return False

    def create(self) -> None:
        """Create the vector extension, the schema and the collection table if missing."""
        if not self.table_exists():
            with self.Session() as sess:
                with sess.begin():
                    logger.debug("Creating extension: vector")
                    sess.execute(text("create extension if not exists vector;"))
                    if self.schema is not None:
                        logger.debug(f"Creating schema: {self.schema}")
                        # NOTE(review): schema is interpolated into raw SQL; it must
                        # come from trusted configuration, never from user input.
                        sess.execute(text(f"create schema if not exists {self.schema};"))
            logger.debug(f"Creating table: {self.collection}")
            self.table.create(self.db_engine)

    def doc_exists(self, document: Document) -> bool:
        """
        Validating if the document exists or not

        Args:
            document (Document): Document to validate
        """
        columns = [self.table.c.name, self.table.c.content_hash]
        with self.Session() as sess:
            with sess.begin():
                # NUL bytes are not valid in Postgres text; replace them the same
                # way insert()/upsert() do so the md5 hashes match stored rows.
                cleaned_content = document.content.replace("\x00", "\ufffd")
                stmt = select(*columns).where(self.table.c.content_hash == md5(cleaned_content.encode()).hexdigest())
                result = sess.execute(stmt).first()
                return result is not None

    def name_exists(self, name: str) -> bool:
        """
        Validate if a row with this name exists or not

        Args:
            name (str): Name to validate
        """
        with self.Session() as sess:
            with sess.begin():
                stmt = select(self.table.c.name).where(self.table.c.name == name)
                result = sess.execute(stmt).first()
                return result is not None

    def insert(self, documents: List[Document], batch_size: int = 10) -> None:
        """Embed and insert documents, committing every `batch_size` rows.

        Args:
            documents: Documents to embed and insert.
            batch_size: Number of inserted rows per commit.
        """
        with self.Session() as sess:
            counter = 0
            for document in documents:
                document.embed(embedder=self.embedder)
                # Postgres text columns cannot store NUL bytes.
                cleaned_content = document.content.replace("\x00", "\ufffd")
                stmt = postgresql.insert(self.table).values(
                    name=document.name,
                    meta_data=document.meta_data,
                    content=cleaned_content,
                    embedding=document.embedding,
                    usage=document.usage,
                    content_hash=md5(cleaned_content.encode()).hexdigest(),
                )
                sess.execute(stmt)
                counter += 1
                logger.debug(f"Inserted document: {document.name} ({document.meta_data})")

                # Commit every `batch_size` documents
                if counter >= batch_size:
                    sess.commit()
                    logger.debug(f"Committed {counter} documents")
                    counter = 0

            # Commit any remaining documents
            if counter > 0:
                sess.commit()
                logger.debug(f"Committed {counter} documents")

    def upsert(self, documents: List[Document]) -> None:
        """
        Upsert documents into the database.

        Args:
            documents (List[Document]): List of documents to upsert
        """
        with self.Session() as sess:
            with sess.begin():
                for document in documents:
                    document.embed(embedder=self.embedder)
                    # Postgres text columns cannot store NUL bytes.
                    cleaned_content = document.content.replace("\x00", "\ufffd")
                    stmt = postgresql.insert(self.table).values(
                        name=document.name,
                        meta_data=document.meta_data,
                        content=cleaned_content,
                        embedding=document.embedding,
                        usage=document.usage,
                        content_hash=md5(cleaned_content.encode()).hexdigest(),
                    )
                    # On (name, content_hash) conflict, refresh the stored row.
                    # NOTE(review): requires a unique index on (name, content_hash)
                    # that get_table() does not declare — verify it exists.
                    stmt = stmt.on_conflict_do_update(
                        index_elements=["name", "content_hash"],
                        set_=dict(
                            meta_data=document.meta_data,
                            content=stmt.excluded.content,
                            embedding=stmt.excluded.embedding,
                            usage=stmt.excluded.usage,
                        ),
                    )
                    sess.execute(stmt)
                    logger.debug(f"Upserted document: {document.name} ({document.meta_data})")

    def search(self, query: str, limit: int = 5) -> List[Document]:
        """Return up to `limit` documents nearest to `query` under self.distance.

        Returns an empty list if the query cannot be embedded.
        """
        query_embedding = self.embedder.get_embedding(query)
        if query_embedding is None:
            logger.error(f"Error getting embedding for Query: {query}")
            return []

        columns = [
            self.table.c.name,
            self.table.c.meta_data,
            self.table.c.content,
            self.table.c.embedding,
            self.table.c.usage,
        ]
        stmt = select(*columns)
        if self.distance == Distance.l2:
            # BUG FIX: previously ordered by max_inner_product for the L2 metric,
            # which returned wrong neighbors; use the pgvector <-> operator.
            stmt = stmt.order_by(self.table.c.embedding.l2_distance(query_embedding))
        if self.distance == Distance.cosine:
            stmt = stmt.order_by(self.table.c.embedding.cosine_distance(query_embedding))
        if self.distance == Distance.max_inner_product:
            stmt = stmt.order_by(self.table.c.embedding.max_inner_product(query_embedding))
        stmt = stmt.limit(limit=limit)
        logger.debug(f"Query: {stmt}")

        # Get neighbors
        with self.Session() as sess:
            with sess.begin():
                if self.index is not None:
                    # Apply per-query index search parameters for this transaction only.
                    if isinstance(self.index, Ivfflat):
                        sess.execute(text(f"SET LOCAL ivfflat.probes = {self.index.probes}"))
                    elif isinstance(self.index, HNSW):
                        sess.execute(text(f"SET LOCAL hnsw.ef_search = {self.index.ef_search}"))
                neighbors = sess.execute(stmt).fetchall() or []

        # Build search results
        search_results: List[Document] = []
        for neighbor in neighbors:
            search_results.append(
                Document(
                    name=neighbor.name,
                    meta_data=neighbor.meta_data,
                    content=neighbor.content,
                    embedder=self.embedder,
                    embedding=neighbor.embedding,
                    usage=neighbor.usage,
                )
            )
        return search_results

    def delete(self) -> None:
        """Drop the collection table if it exists."""
        if self.table_exists():
            logger.debug(f"Deleting table: {self.collection}")
            self.table.drop(self.db_engine)

    def exists(self) -> bool:
        """Return True if the collection table exists (VectorDb interface)."""
        return self.table_exists()

    def get_count(self) -> int:
        """Return the number of rows in the collection (0 on no result)."""
        with self.Session() as sess:
            with sess.begin():
                stmt = select(func.count(self.table.c.name)).select_from(self.table)
                result = sess.execute(stmt).scalar()
                if result is not None:
                    return int(result)
                return 0

    def optimize(self) -> None:
        """Create the configured vector index (ivfflat or hnsw) on the embedding column.

        For a dynamic-lists Ivfflat index, `lists` follows the pgvector guidance:
        rows/1000 up to 1M rows, sqrt(rows) beyond that.
        """
        from math import sqrt

        logger.debug("==== Optimizing Vector DB ====")
        if self.index is None:
            return

        # Compute the index name locally instead of assigning self.index.name:
        # the default index object is shared between instances (mutable default),
        # and mutating it would leak one collection's index name into another.
        _type = "ivfflat" if isinstance(self.index, Ivfflat) else "hnsw"
        index_name = self.index.name or f"{self.collection}_{_type}_index"

        index_distance = "vector_cosine_ops"
        if self.distance == Distance.l2:
            index_distance = "vector_l2_ops"
        if self.distance == Distance.max_inner_product:
            index_distance = "vector_ip_ops"

        if isinstance(self.index, Ivfflat):
            num_lists = self.index.lists
            if self.index.dynamic_lists:
                total_records = self.get_count()
                logger.debug(f"Number of records: {total_records}")
                if total_records < 1000000:
                    # BUG FIX: clamp to >= 1; pgvector rejects lists = 0,
                    # which the old `total_records / 1000` produced for small tables.
                    num_lists = max(int(total_records / 1000), 1)
                else:
                    # BUG FIX: the old `elif total_records > 1000000` left exactly
                    # 1,000,000 records unhandled.
                    num_lists = max(int(sqrt(total_records)), 1)

            with self.Session() as sess:
                with sess.begin():
                    logger.debug(f"Setting configuration: {self.index.configuration}")
                    for key, value in self.index.configuration.items():
                        sess.execute(text(f"SET {key} = '{value}';"))
                    logger.debug(
                        f"Creating Ivfflat index with lists: {num_lists}, probes: {self.index.probes} "
                        f"and distance metric: {index_distance}"
                    )
                    sess.execute(text(f"SET ivfflat.probes = {self.index.probes};"))
                    sess.execute(
                        text(
                            f"CREATE INDEX IF NOT EXISTS {index_name} ON {self.table} "
                            f"USING ivfflat (embedding {index_distance}) "
                            f"WITH (lists = {num_lists});"
                        )
                    )
        elif isinstance(self.index, HNSW):
            with self.Session() as sess:
                with sess.begin():
                    logger.debug(f"Setting configuration: {self.index.configuration}")
                    for key, value in self.index.configuration.items():
                        sess.execute(text(f"SET {key} = '{value}';"))
                    logger.debug(
                        f"Creating HNSW index with m: {self.index.m}, ef_construction: {self.index.ef_construction} "
                        f"and distance metric: {index_distance}"
                    )
                    sess.execute(
                        text(
                            f"CREATE INDEX IF NOT EXISTS {index_name} ON {self.table} "
                            f"USING hnsw (embedding {index_distance}) "
                            f"WITH (m = {self.index.m}, ef_construction = {self.index.ef_construction});"
                        )
                    )
        logger.debug("==== Optimized Vector DB ====")

    def clear(self) -> bool:
        """Delete all rows from the collection table; always returns True."""
        from sqlalchemy import delete

        with self.Session() as sess:
            with sess.begin():
                stmt = delete(self.table)
                sess.execute(stmt)
                return True