File size: 4,703 Bytes
e676d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
import os
import time
import numpy as np
from loguru import logger
import stamina
from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict

class MyQdrantClient:
    """Small convenience wrapper around a ``QdrantClient`` opened with ``path=``.

    The wrapper creates collections for either multi-vector ("colbert") or
    single-vector ("dense") embeddings, uploads points with a retry on
    failure, and runs top-k queries returning only the matching point ids.

    Error-handling convention: the upsert/query helpers log failures via
    ``loguru`` and signal them through their return value (``False`` /
    ``None``) instead of raising. ``create_collection`` and
    ``delete_collection`` let client exceptions propagate.
    """

    # Values accepted by ``create_collection(vector_type=...)``.
    _SUPPORTED_VECTOR_TYPES = ("colbert", "dense")

    def __init__(self, path: str):
        """Open a Qdrant client using ``path`` as its ``path=`` argument.

        Args:
            path: Passed straight through to ``QdrantClient(path=...)``.
        """
        self.qdrant_client = QdrantClient(path=path)
        logger.debug(f"Qdrant client created at {path}")

    def create_collection(self, collection_name: str, vector_dim: int = 128, vector_type: str = "colbert"):
        """Create a cosine-distance collection with vectors and payload on disk.

        Args:
            collection_name: Name of the collection to create.
            vector_dim: Size of each individual vector.
            vector_type: ``"colbert"`` adds a multi-vector configuration with
                the MAX_SIM comparator; ``"dense"`` creates a plain
                single-vector collection.

        Raises:
            ValueError: If ``vector_type`` is neither ``"colbert"`` nor
                ``"dense"``. Raised before the client is contacted.
        """
        if vector_type not in self._SUPPORTED_VECTOR_TYPES:
            raise ValueError(f"Vector type {vector_type} not supported")

        # Settings shared by both collection flavours.
        vector_kwargs: Dict[str, Any] = {
            "size": vector_dim,
            "distance": models.Distance.COSINE,
            "on_disk": True,  # move original vectors to disk
        }
        if vector_type == "colbert":
            # Only multi-vector collections get a comparator; the key is left
            # out entirely for "dense" so the request matches a plain
            # VectorParams(size, distance, on_disk).
            vector_kwargs["multivector_config"] = models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            )
        # NOTE: no quantization_config is set; a binary-quantization setup
        # (models.BinaryQuantization with always_ram=True) was previously
        # sketched here but is not enabled.

        self.qdrant_client.create_collection(
            collection_name=collection_name,
            on_disk_payload=True,  # store the payload on disk
            vectors_config=models.VectorParams(**vector_kwargs),
        )

        logger.debug(f"Qdrant collection of type {vector_type} : {collection_name} created")

    def delete_collection(self, collection_name: str):
        """Delete ``collection_name`` via the underlying client."""
        self.qdrant_client.delete_collection(collection_name=collection_name)

    @stamina.retry(on=Exception, attempts=3)
    def _upsert_with_retry(self, batch, collection_name: str) -> None:
        """Send one upsert request, letting exceptions propagate.

        Exceptions must escape this method for ``stamina`` to see them and
        schedule another attempt; callers handle the final failure.
        """
        self.qdrant_client.upsert(
            collection_name=collection_name,
            points=batch,
            wait=False,
        )

    def upsert_to_qdrant(self, batch, collection_name: str) -> bool:
        """Upsert ``batch`` into ``collection_name``, retrying on failure.

        The request is attempted up to three times (see
        ``_upsert_with_retry``). Previously the exception was caught inside
        the retried function itself, so the retry decorator never observed a
        failure and no retry ever happened.

        Args:
            batch: Points to upload, forwarded as ``points=`` to the client.
            collection_name: Target collection.

        Returns:
            ``True`` if an attempt succeeded, ``False`` if every attempt
            raised (the last error is logged). This method does not raise.
        """
        try:
            self._upsert_with_retry(batch, collection_name)
        except Exception as e:
            logger.error(f"Error during upsert: {e}")
            return False
        return True

    def upsert_multivector(self, index: int, multivector_input_list: list[Any], collection_name: str) -> bool:
        """Wrap each multivector in a point and upload them as one batch.

        Point ids are assigned consecutively starting at ``index``
        (``index``, ``index + 1``, ...), and every point carries the payload
        ``{"source": "user uploaded data"}``.

        Args:
            index: Id given to the first point of the batch.
            multivector_input_list: One entry per point; each entry is passed
                as the point's ``vector``.
            collection_name: Target collection.

        Returns:
            ``True`` if the batch was uploaded (or there was nothing to
            upload), ``False`` if building the points or the upload failed.
            Failures are logged, not raised.
        """
        try:
            points = [
                models.PointStruct(
                    id=index + offset,  # we just use the running index as the ID
                    vector=multivector,  # a list of vectors for multi-vector collections
                    payload={"source": "user uploaded data"},  # can also add other metadata/data
                )
                for offset, multivector in enumerate(multivector_input_list)
            ]
        except Exception as e:
            logger.error(f"Vector DB client - error during upsert: {e}")
            return False

        if not points:
            # Nothing to send; avoid issuing an empty upsert request.
            return True

        return self.upsert_to_qdrant(points, collection_name)

    def query_multivector(self, multivector_input, collection_name: str, top_k: int = 10) -> Optional[list[int]]:
        """Return the ids of the ``top_k`` best matches for a query.

        Args:
            multivector_input: Query forwarded as ``query=`` to
                ``query_points``.
            collection_name: Collection to search.
            top_k: Maximum number of results (``limit``).

        Returns:
            The ``id`` of each returned point, in the order the client
            returned them, or ``None`` if the query raised (the error is
            logged).
        """
        try:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by wall-clock adjustments.
            started = time.perf_counter()
            # No timeout or search_params are passed; client defaults apply.
            search_result = self.qdrant_client.query_points(
                collection_name=collection_name,
                query=multivector_input,
                limit=top_k,
            )
            elapsed_time = time.perf_counter() - started
            logger.debug(f"Search completed in {elapsed_time:.4f} seconds")

            return [point.id for point in search_result.points]

        except Exception as e:
            logger.error(f"Error during query: {e}")
            return None

    def close(self) -> None:
        """Close the underlying client once; later calls are no-ops.

        Safe to call on a partially constructed instance (for example when
        ``__init__`` raised before ``qdrant_client`` was assigned).
        """
        client = getattr(self, "qdrant_client", None)
        if client is None or getattr(self, "_closed", False):
            return
        # Mark first so a failing close() is not re-attempted from __del__.
        self._closed = True
        client.close()

    def __del__(self):
        """Release the client when the wrapper is garbage collected."""
        # Delegates to close(), which tolerates a missing ``qdrant_client``
        # attribute instead of raising AttributeError during finalization.
        self.close()