root
upload
e676d24
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
import os
import time
import numpy as np
from loguru import logger
import stamina
from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict
class MyQdrantClient:
    """Convenience wrapper around a local-mode ``QdrantClient``.

    Supports creating/deleting cosine-distance collections (ColBERT-style
    multivector or single dense vector), upserting batches of points, and
    querying for the ids of the nearest points.
    """

    def __init__(self, path: str):
        """Open a Qdrant client backed by local storage at ``path``.

        Args:
            path: Filesystem path passed to ``QdrantClient(path=...)``.
        """
        self.qdrant_client = QdrantClient(path=path)
        logger.debug(f"Qdrant client created at {path}")

    def create_collection(
        self,
        collection_name: str,
        vector_dim: int = 128,
        vector_type: Literal["colbert", "dense"] = "colbert",
    ):
        """Create a cosine-distance collection with payload and vectors on disk.

        Args:
            collection_name: Name of the collection to create.
            vector_dim: Dimensionality of each vector (each sub-vector for
                ``"colbert"``).
            vector_type: ``"colbert"`` stores a list of vectors per point and
                compares them with MAX_SIM; ``"dense"`` stores one vector per
                point.

        Raises:
            ValueError: If ``vector_type`` is not ``"colbert"`` or ``"dense"``.
        """
        vector_params: Dict[str, Any] = {
            "size": vector_dim,
            "distance": models.Distance.COSINE,
            "on_disk": True,  # move original vectors to disk
        }
        if vector_type == "colbert":
            # Late-interaction scoring over the per-point list of vectors.
            vector_params["multivector_config"] = models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            )
            # NOTE: binary quantization (models.BinaryQuantization with
            # always_ram=True, keeping only quantized vectors in RAM) is not
            # enabled here.
        elif vector_type != "dense":
            raise ValueError(f"Vector type {vector_type} not supported")

        self.qdrant_client.create_collection(
            collection_name=collection_name,
            on_disk_payload=True,  # store the payload on disk
            vectors_config=models.VectorParams(**vector_params),
        )
        logger.debug(f"Qdrant collection of type {vector_type} : {collection_name} created")

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name``."""
        self.qdrant_client.delete_collection(collection_name=collection_name)
        logger.debug(f"Qdrant collection {collection_name} deleted")

    @stamina.retry(on=Exception, attempts=3)
    def _upsert_with_retry(self, batch, collection_name: str) -> None:
        """Upsert ``batch`` once, letting any exception propagate.

        Raising (instead of catching) is what allows ``stamina.retry`` to
        re-attempt the call; after the last attempt the exception escapes.
        """
        self.qdrant_client.upsert(
            collection_name=collection_name,
            points=batch,
            wait=False,  # do not block until the upsert has been applied
        )

    def upsert_to_qdrant(self, batch, collection_name: str) -> bool:
        """Upsert a batch of points, retrying up to 3 attempts on failure.

        Args:
            batch: Points to upsert (e.g. a list of ``models.PointStruct``).
            collection_name: Target collection.

        Returns:
            True if the upsert request succeeded, False if every attempt failed.
        """
        try:
            self._upsert_with_retry(batch, collection_name)
        except Exception as e:
            logger.error(f"Error during upsert: {e}")
            return False
        return True

    @staticmethod
    def _as_vector_payload(vector: Any) -> Any:
        """Convert NumPy arrays to nested Python lists accepted by ``PointStruct``."""
        return vector.tolist() if isinstance(vector, np.ndarray) else vector

    def upsert_multivector(self, index: int, multivector_input_list: list[Any], collection_name: str) -> bool:
        """Upsert one point per entry of ``multivector_input_list``.

        Point ids are consecutive integers starting at ``index`` (entry ``j``
        gets id ``index + j``), and every point gets the same fixed payload.

        Args:
            index: Id of the first point.
            multivector_input_list: Vectors to store; each entry is a list of
                vectors (or a single vector for a dense collection). NumPy
                arrays are converted to lists.
            collection_name: Target collection.

        Returns:
            True on success (or when there is nothing to upsert), False if
            building the points or the upsert failed (the error is logged).
        """
        if not multivector_input_list:
            logger.debug("Vector DB client - nothing to upsert")
            return True
        try:
            points = [
                models.PointStruct(
                    id=index + offset,  # we just use the index as the ID
                    vector=self._as_vector_payload(multivector),
                    payload={
                        "source": "user uploaded data"
                    },  # can also add other metadata/data
                )
                for offset, multivector in enumerate(multivector_input_list)
            ]
            return self.upsert_to_qdrant(points, collection_name)
        except Exception as e:
            logger.error(f"Vector DB client - error during upsert: {e}")
            return False

    def query_multivector(self, multivector_input, collection_name: str, top_k: int = 10) -> Optional[list[int]]:
        """Return the ids of the ``top_k`` best-matching points.

        Args:
            multivector_input: Query vector(s), passed unchanged as ``query``
                to ``QdrantClient.query_points``.
            collection_name: Collection to search.
            top_k: Maximum number of results.

        Returns:
            Point ids in the order returned by Qdrant, or None if the query
            failed (the error is logged).
        """
        try:
            start_time = time.perf_counter()  # monotonic, suited to durations
            search_result = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=multivector_input,
                limit=top_k,
                # NOTE: quantization search params (rescore/oversampling) only
                # matter once quantization is enabled on the collection.
            )
            elapsed_time = time.perf_counter() - start_time
            logger.debug(f"Search completed in {elapsed_time:.4f} seconds")
            return [point.id for point in search_result.points]
        except Exception as e:
            logger.error(f"Error during query: {e}")
            return None

    def __del__(self):
        """Close the underlying client when this wrapper is garbage-collected."""
        # __init__ may have failed before the attribute was assigned.
        client = getattr(self, "qdrant_client", None)
        if client is None:
            return
        try:
            client.close()
        except Exception:
            # Finalizers must not raise (e.g. during interpreter shutdown).
            pass