from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
import os

class MilvusManager:
    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        # Load environment variables from a .env file (search() reads the "topk" setting).
        import dotenv

        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        # One row per ColBERT token vector: "doc_id" groups rows belonging to the
        # same document, "seq_id" orders tokens within it, and "doc" carries the
        # source filepath (set only on the first row; see insert()).
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        index_params = self.client.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        # Retrieve all collection names from the Milvus client.
        collections = self.client.list_collections()

        # Use cosine similarity (Milvus would otherwise default to "IP").
        search_params = {"metric_type": "COSINE", "params": {}}

        # Set to store unique (doc_id, collection_name) pairs across all collections.
        doc_collection_pairs = set()

        # Query each collection individually.
        for collection in collections:
            self.client.load_collection(collection_name=collection)
            print("Collection loaded: " + collection)
            results = self.client.search(
                collection,
                data,
                limit=int(os.environ.get("topk", 50)),  # Per-collection candidate limit (default 50).
                output_fields=["vector", "seq_id", "doc_id"],
                search_params=search_params,
            )
            # Accumulate document IDs along with their originating collection.
            for query_hits in results:
                for hit in query_hits:
                    doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Fetch all stored token vectors for this document only.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=16380,
            )
            # Stack the token vectors into a (num_tokens, dim) matrix.
            doc_vecs = np.vstack(
                [doc["vector"] for doc in doc_colbert_vecs]
            )
            # MaxSim late-interaction score: for each query token, take its best
            # match among the document tokens, then sum over the query tokens.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id, collection_name)

        # Use a thread pool to rerank each document concurrently.
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(rerank_single_doc, doc_id, data, self.client, collection): (doc_id, collection)
                for doc_id, collection in doc_collection_pairs
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id, collection = future.result()
                scores.append((score, doc_id, collection)) 
                #doc_id is page number!
        
        # Sort the reranked results by score in descending order.
        scores.sort(key=lambda x: x[0], reverse=True)
        # Release every collection loaded above to free memory.
        for collection in collections:
            self.client.release_collection(collection_name=collection)

        return scores[:topk]
        # Retained for reference: the earlier single-collection implementation,
        # left unreachable after the return statement above.
        """
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,
            output_fields=["vector", "seq_id", "doc_id"],
            search_params=search_params,
        )
        doc_ids = {result["entity"]["doc_id"] for result in results[0]}

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=1000,
            )
            doc_vecs = np.vstack(
                [doc["vector"] for doc in doc_colbert_vecs]
            )
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return score, doc_id

        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id = future.result()
                scores.append((score, doc_id))

        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:topk]
        """

    def insert(self, data):
        # Insert one document as one row per ColBERT token vector.
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length
        seq_ids = list(range(seq_length))
        # The source filepath is stored only on the document's first row.
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors):
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)
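

# --- Usage sketch (not part of the original module; illustrative only) ---
# A minimal end-to-end example of how this class might be driven. The URI,
# collection name, and random embeddings below are hypothetical stand-ins:
# real ColBERT/ColPali embeddings would come from a model as one
# (num_tokens, dim) array per page image.
if __name__ == "__main__":
    manager = MilvusManager(
        milvus_uri="http://localhost:19530",  # assumed local Milvus instance
        collection_name="demo_pages",  # hypothetical collection name
        create_collection=True,
        dim=128,
    )

    # Fake per-page token embeddings standing in for model output.
    pages = [
        {"colbert_vecs": np.random.rand(16, 128).astype(np.float32), "filepath": "page_0.png"},
        {"colbert_vecs": np.random.rand(16, 128).astype(np.float32), "filepath": "page_1.png"},
    ]
    manager.insert_images_data(pages)

    # Query with a fake (num_query_tokens, dim) embedding. search() returns
    # (score, doc_id, collection_name) tuples, best match first.
    query_vecs = np.random.rand(8, 128).astype(np.float32)
    print(manager.search(query_vecs, topk=2))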