# AmmarFahmy
# adding all files
# 105b369
from hashlib import md5
from typing import List, Optional
# qdrant-client is an optional dependency: fail fast with an actionable install hint.
try:
    from qdrant_client import QdrantClient  # noqa: F401
    from qdrant_client.http import models
except ImportError as err:
    raise ImportError(
        "The `qdrant-client` package is not installed. "
        "Please install it via `pip install qdrant-client`."
    ) from err
from phi.document import Document
from phi.embedder import Embedder
from phi.embedder.openai import OpenAIEmbedder
from phi.vectordb.base import VectorDb
from phi.vectordb.distance import Distance
from phi.utils.log import logger
class Qdrant(VectorDb):
    """VectorDb implementation backed by a single Qdrant collection.

    Each document is embedded with ``embedder`` and stored as one Qdrant point.
    The point id is the MD5 hex digest of the document content (after NUL
    characters are replaced with U+FFFD), so identical content always maps to the
    same point id and re-inserting it overwrites the existing point.
    """

    def __init__(
        self,
        collection: str,
        embedder: Optional[Embedder] = None,
        distance: Distance = Distance.cosine,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            collection: Name of the Qdrant collection to read from / write to.
            embedder: Embedder used for both documents and queries. When omitted,
                a fresh ``OpenAIEmbedder()`` is created for this instance. It is
                created here rather than as a default argument, so separate
                instances never share one embedder object and nothing is
                instantiated at import time.
            distance: Distance metric applied when ``create()`` builds the collection.
            location, url, port, grpc_port, prefer_grpc, https, api_key, prefix,
            timeout, host, path, **kwargs: Forwarded unchanged to ``QdrantClient``
                when the client is first created (see the ``client`` property).
        """
        # Collection attributes
        self.collection: str = collection
        # Embedder for embedding the document contents
        self.embedder: Embedder = embedder if embedder is not None else OpenAIEmbedder()
        self.dimensions: int = self.embedder.dimensions
        # Distance metric
        self.distance: Distance = distance
        # Qdrant client instance (created lazily by the `client` property)
        self._client: Optional[QdrantClient] = None
        # Qdrant client arguments
        self.location: Optional[str] = location
        self.url: Optional[str] = url
        self.port: Optional[int] = port
        self.grpc_port: int = grpc_port
        self.prefer_grpc: bool = prefer_grpc
        self.https: Optional[bool] = https
        self.api_key: Optional[str] = api_key
        self.prefix: Optional[str] = prefix
        self.timeout: Optional[float] = timeout
        self.host: Optional[str] = host
        self.path: Optional[str] = path
        # Extra Qdrant client kwargs
        self.kwargs = kwargs

    @property
    def client(self) -> QdrantClient:
        """Return the ``QdrantClient``, creating and caching it on first access."""
        if self._client is None:
            logger.debug("Creating Qdrant Client")
            self._client = QdrantClient(
                location=self.location,
                url=self.url,
                port=self.port,
                grpc_port=self.grpc_port,
                prefer_grpc=self.prefer_grpc,
                https=self.https,
                api_key=self.api_key,
                prefix=self.prefix,
                timeout=self.timeout,
                host=self.host,
                path=self.path,
                **self.kwargs,
            )
        return self._client

    @staticmethod
    def _clean_content(content: str) -> str:
        """Replace NUL characters with U+FFFD before hashing/storing content."""
        # NOTE(review): presumably done because NUL is unsafe in stored payloads — confirm.
        return content.replace("\x00", "\ufffd")

    @staticmethod
    def _doc_id(cleaned_content: str) -> str:
        """Return the point id for already-cleaned content (MD5 hex digest).

        MD5 is used as a content fingerprint here, not for security.
        """
        return md5(cleaned_content.encode()).hexdigest()

    def create(self) -> None:
        """Create the collection if it does not already exist.

        Vectors are sized to ``self.dimensions``. The Qdrant distance is EUCLID for
        ``Distance.l2``, DOT for ``Distance.max_inner_product``, and COSINE otherwise.
        """
        if self.exists():
            return
        _distance = models.Distance.COSINE
        if self.distance == Distance.l2:
            _distance = models.Distance.EUCLID
        elif self.distance == Distance.max_inner_product:
            _distance = models.Distance.DOT
        logger.debug(f"Creating collection: {self.collection}")
        self.client.create_collection(
            collection_name=self.collection,
            vectors_config=models.VectorParams(size=self.dimensions, distance=_distance),
        )

    def doc_exists(self, document: Document) -> bool:
        """
        Validating if the document exists or not

        Args:
            document (Document): Document to validate

        Returns:
            bool: True if a point with this document's content-derived id exists.
        """
        doc_id = self._doc_id(self._clean_content(document.content))
        # Only existence matters, so skip transferring payload and vectors.
        collection_points = self.client.retrieve(
            collection_name=self.collection,
            ids=[doc_id],
            with_payload=False,
            with_vectors=False,
        )
        return len(collection_points) > 0

    def name_exists(self, name: str) -> bool:
        """
        Check whether any stored point has payload field ``name`` equal to ``name``.

        Args:
            name (str): Document name to look for.

        Returns:
            bool: True if at least one matching point exists.
        """
        # Filtering on an un-indexed payload field works in Qdrant (it scans);
        # limit=1 because only existence is needed.
        points, _next_offset = self.client.scroll(
            collection_name=self.collection,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="name", match=models.MatchValue(value=name))]
            ),
            limit=1,
            with_payload=False,
            with_vectors=False,
        )
        return len(points) > 0

    def insert(self, documents: List[Document], batch_size: int = 10) -> None:
        """
        Embed documents and upsert them as Qdrant points.

        Args:
            documents (List[Document]): Documents to embed and store. Each document is
                embedded in place via ``document.embed``.
            batch_size (int): Maximum number of points sent per upsert request.
                A value <= 0 sends all points in a single request.
        """
        logger.debug(f"Inserting {len(documents)} documents")
        points: List[models.PointStruct] = []
        for document in documents:
            document.embed(embedder=self.embedder)
            cleaned_content = self._clean_content(document.content)
            points.append(
                models.PointStruct(
                    id=self._doc_id(cleaned_content),
                    vector=document.embedding,
                    payload={
                        "name": document.name,
                        "meta_data": document.meta_data,
                        "content": cleaned_content,
                        "usage": document.usage,
                    },
                )
            )
            logger.debug(f"Inserted document: {document.name} ({document.meta_data})")
        # max(..., 1) keeps the range step positive when there are no points;
        # with no points the loop below simply does nothing.
        step = batch_size if batch_size > 0 else max(len(points), 1)
        for start in range(0, len(points), step):
            # wait=False: the client does not block until the update is applied.
            self.client.upsert(
                collection_name=self.collection,
                wait=False,
                points=points[start : start + step],
            )
        logger.debug(f"Upsert {len(points)} documents")

    def upsert(self, documents: List[Document]) -> None:
        """
        Upsert documents into the database.

        Point ids are derived from content, so ``insert`` already overwrites existing
        points; this simply delegates to it with the default batch size.

        Args:
            documents (List[Document]): List of documents to upsert
        """
        logger.debug("Redirecting the request to insert")
        self.insert(documents)

    def search(self, query: str, limit: int = 5) -> List[Document]:
        """
        Return up to ``limit`` documents nearest to ``query``.

        Args:
            query (str): Text to embed and search with.
            limit (int): Maximum number of results.

        Returns:
            List[Document]: Matching documents rebuilt from point payloads (with their
            stored vectors). Empty if the query could not be embedded. Points without
            a payload are skipped.
        """
        query_embedding = self.embedder.get_embedding(query)
        if query_embedding is None:
            logger.error(f"Error getting embedding for Query: {query}")
            return []

        results = self.client.search(
            collection_name=self.collection,
            query_vector=query_embedding,
            with_vectors=True,
            with_payload=True,
            limit=limit,
        )

        search_results: List[Document] = []
        for result in results:
            if result.payload is None:
                continue
            search_results.append(
                Document(
                    name=result.payload["name"],
                    meta_data=result.payload["meta_data"],
                    content=result.payload["content"],
                    embedder=self.embedder,
                    embedding=result.vector,
                    usage=result.payload["usage"],
                )
            )
        return search_results

    def delete(self) -> None:
        """Delete the collection if it exists."""
        if self.exists():
            logger.debug(f"Deleting collection: {self.collection}")
            self.client.delete_collection(self.collection)

    def exists(self) -> bool:
        """Return True if ``self.collection`` is listed among the server's collections."""
        collections_response: models.CollectionsResponse = self.client.get_collections()
        return any(
            collection.name == self.collection for collection in collections_response.collections
        )

    def get_count(self) -> int:
        """Return the exact number of points in the collection."""
        count_result: models.CountResult = self.client.count(collection_name=self.collection, exact=True)
        return count_result.count

    def optimize(self) -> None:
        """No-op: this implementation performs no optimization step."""
        pass

    def clear(self) -> bool:
        """Not implemented for Qdrant: removes nothing and always returns False."""
        return False