Spaces:
Runtime error
Runtime error
from hashlib import md5
from typing import List, Optional

try:
    from qdrant_client import QdrantClient  # noqa: F401
    from qdrant_client.http import models
except ImportError:
    # Fail fast with an actionable install hint when the optional dependency is missing.
    raise ImportError(
        "The `qdrant-client` package is not installed. "
        "Please install it via `pip install qdrant-client`."
    )

from phi.document import Document
from phi.embedder import Embedder
from phi.embedder.openai import OpenAIEmbedder
from phi.vectordb.base import VectorDb
from phi.vectordb.distance import Distance
from phi.utils.log import logger
class Qdrant(VectorDb):
    """Vector database backed by a single Qdrant collection.

    Documents are embedded with ``self.embedder`` and stored as points. Each point's id
    is the MD5 hex digest of the document content (with NUL characters replaced by
    U+FFFD), so writing the same content again via ``upsert`` targets the same point id.
    """

    def __init__(
        self,
        collection: str,
        embedder: Optional[Embedder] = None,
        distance: Distance = Distance.cosine,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            collection: Name of the Qdrant collection this instance reads from and writes to.
            embedder: Embedder used for both documents and queries. When omitted, a new
                ``OpenAIEmbedder`` is created for this instance (not shared between instances).
            distance: Distance metric applied when ``create`` builds the collection.
            location, url, port, grpc_port, prefer_grpc, https, api_key, prefix, timeout,
            host, path, **kwargs: Stored and forwarded unchanged to ``QdrantClient``.
        """
        # Collection attributes
        self.collection: str = collection
        # Embedder for embedding the document contents
        self.embedder: Embedder = embedder if embedder is not None else OpenAIEmbedder()
        self.dimensions: int = self.embedder.dimensions
        # Distance metric
        self.distance: Distance = distance
        # Qdrant client instance, created lazily by the `client` property
        self._client: Optional[QdrantClient] = None
        # Qdrant client arguments
        self.location: Optional[str] = location
        self.url: Optional[str] = url
        self.port: Optional[int] = port
        self.grpc_port: int = grpc_port
        self.prefer_grpc: bool = prefer_grpc
        self.https: Optional[bool] = https
        self.api_key: Optional[str] = api_key
        self.prefix: Optional[str] = prefix
        self.timeout: Optional[float] = timeout
        self.host: Optional[str] = host
        self.path: Optional[str] = path
        # Extra keyword arguments for QdrantClient
        self.kwargs = kwargs

    @property
    def client(self) -> QdrantClient:
        """Return the cached ``QdrantClient``, creating it on first access.

        This must be a property: every other method accesses it as an attribute
        (``self.client.upsert(...)``, ``self.client.search(...)``).
        """
        if self._client is None:
            logger.debug("Creating Qdrant Client")
            self._client = QdrantClient(
                location=self.location,
                url=self.url,
                port=self.port,
                grpc_port=self.grpc_port,
                prefer_grpc=self.prefer_grpc,
                https=self.https,
                api_key=self.api_key,
                prefix=self.prefix,
                timeout=self.timeout,
                host=self.host,
                path=self.path,
                **self.kwargs,
            )
        return self._client

    @staticmethod
    def _clean_content(content: str) -> str:
        """Replace NUL characters with U+FFFD; used both for hashing and for the stored payload."""
        return content.replace("\x00", "\ufffd")

    @staticmethod
    def _doc_id(cleaned_content: str) -> str:
        """Return the point id for already-cleaned content: its MD5 hex digest.

        NOTE: relies on Qdrant accepting a 32-char hex string as an un-hyphenated UUID id.
        """
        return md5(cleaned_content.encode()).hexdigest()

    def create(self) -> None:
        """Create the collection (vector size = embedder dimensions) if it does not exist yet."""
        # Map phi's Distance onto Qdrant's; anything other than l2 / max_inner_product is cosine.
        _distance = models.Distance.COSINE
        if self.distance == Distance.l2:
            _distance = models.Distance.EUCLID
        elif self.distance == Distance.max_inner_product:
            _distance = models.Distance.DOT

        if not self.exists():
            logger.debug(f"Creating collection: {self.collection}")
            self.client.create_collection(
                collection_name=self.collection,
                vectors_config=models.VectorParams(size=self.dimensions, distance=_distance),
            )

    def doc_exists(self, document: Document) -> bool:
        """
        Validating if the document exists or not

        Args:
            document (Document): Document to validate

        Returns:
            bool: True if a point whose id is derived from the document's content exists.
        """
        doc_id = self._doc_id(self._clean_content(document.content))
        # Only existence matters, so skip transferring payload and vectors.
        collection_points = self.client.retrieve(
            collection_name=self.collection,
            ids=[doc_id],
            with_payload=False,
            with_vectors=False,
        )
        return len(collection_points) > 0

    def name_exists(self, name: str) -> bool:
        """Return True if at least one point has ``payload["name"] == name``.

        Args:
            name: Document name, matched against the ``"name"`` payload key written by ``insert``.
        """
        points, _next_offset = self.client.scroll(
            collection_name=self.collection,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="name", match=models.MatchValue(value=name))]
            ),
            limit=1,
            with_payload=False,
            with_vectors=False,
        )
        return len(points) > 0

    def insert(self, documents: List[Document], batch_size: int = 10) -> None:
        """Embed ``documents`` and upsert them as points.

        Args:
            documents: Documents to embed and store. Each is embedded in place via
                ``document.embed``.
            batch_size: Maximum number of points sent per upsert request. A value <= 0
                sends all points in a single request.
        """
        logger.debug(f"Inserting {len(documents)} documents")
        points: List[models.PointStruct] = []
        for document in documents:
            document.embed(embedder=self.embedder)
            if document.embedding is None:
                # Mirrors `search`, which logs and bails out when an embedding is missing.
                logger.error(f"Error getting embedding for Document: {document.name}")
                continue
            cleaned_content = self._clean_content(document.content)
            points.append(
                models.PointStruct(
                    id=self._doc_id(cleaned_content),
                    vector=document.embedding,
                    payload={
                        "name": document.name,
                        "meta_data": document.meta_data,
                        "content": cleaned_content,
                        "usage": document.usage,
                    },
                )
            )
            logger.debug(f"Inserted document: {document.name} ({document.meta_data})")

        if points:
            step = batch_size if batch_size > 0 else len(points)
            for start in range(0, len(points), step):
                # NOTE: wait=False — Qdrant may acknowledge before the write is applied,
                # so reads issued immediately afterwards can lag behind.
                self.client.upsert(
                    collection_name=self.collection,
                    wait=False,
                    points=points[start : start + step],
                )
        logger.debug(f"Upsert {len(points)} documents")

    def upsert(self, documents: List[Document]) -> None:
        """
        Upsert documents into the database.

        Args:
            documents (List[Document]): List of documents to upsert
        """
        # insert already writes via client.upsert with content-derived ids.
        logger.debug("Redirecting the request to insert")
        self.insert(documents)

    def search(self, query: str, limit: int = 5) -> List[Document]:
        """Return up to ``limit`` documents closest to ``query``.

        Returns an empty list (after logging an error) if the query cannot be embedded.
        Points without a payload are skipped.
        """
        query_embedding = self.embedder.get_embedding(query)
        if query_embedding is None:
            logger.error(f"Error getting embedding for Query: {query}")
            return []

        results = self.client.search(
            collection_name=self.collection,
            query_vector=query_embedding,
            with_vectors=True,
            with_payload=True,
            limit=limit,
        )

        # Build search results
        search_results: List[Document] = []
        for result in results:
            if result.payload is None:
                continue
            search_results.append(
                Document(
                    name=result.payload["name"],
                    meta_data=result.payload["meta_data"],
                    content=result.payload["content"],
                    embedder=self.embedder,
                    embedding=result.vector,
                    usage=result.payload["usage"],
                )
            )
        return search_results

    def delete(self) -> None:
        """Drop the collection if it exists."""
        if self.exists():
            logger.debug(f"Deleting collection: {self.collection}")
            self.client.delete_collection(self.collection)

    def exists(self) -> bool:
        """Return True if a collection named ``self.collection`` is listed by the server."""
        collections_response: models.CollectionsResponse = self.client.get_collections()
        collections: List[models.CollectionDescription] = collections_response.collections
        # The collection's status (e.g. GREEN) is not checked, only its presence.
        return any(collection.name == self.collection for collection in collections)

    def get_count(self) -> int:
        """Return the exact number of points in the collection."""
        count_result: models.CountResult = self.client.count(collection_name=self.collection, exact=True)
        return count_result.count

    def optimize(self) -> None:
        """No-op for Qdrant."""
        pass

    def clear(self) -> bool:
        """Not supported: does nothing and returns False."""
        return False