from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
import os


class MilvusManager:
    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        # Import environment variables from the .env file.
        import dotenv

        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name="vector_index",
            params={"nlist": 128},
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        # Retrieve all collection names from the Milvus client.
        collections = self.client.list_collections()

        # Set search parameters (COSINE to match the index; Milvus defaults to IP).
        search_params = {"metric_type": "COSINE", "params": {}}

        # Set to store unique (doc_id, collection_name) pairs across all collections.
        doc_collection_pairs = set()

        # Query each collection individually.
        for collection in collections:
            self.client.load_collection(collection_name=collection)
            print("Collection loaded: " + collection)
            results = self.client.search(
                collection,
                data,
                # Per-collection candidate limit read from the environment;
                # adjust as needed (default is 50).
                limit=int(os.environ["topk"]),
                output_fields=["vector", "seq_id", "doc_id"],
                search_params=search_params,
            )
            # Accumulate document IDs along with their originating collection.
            for r_id in range(len(results)):
                for r in range(len(results[r_id])):
                    doc_id = results[r_id][r]["entity"]["doc_id"]
                    doc_collection_pairs.add((doc_id, collection))

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Query for the document's token-level vectors in the given collection.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=16380,
            )
            # Stack the vectors for the dot-product computation.
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            # Compute a MaxSim-style similarity score via dot product.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id, collection_name)

        # Use a thread pool to rerank each candidate document concurrently.
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, collection
                ): (doc_id, collection)
                for doc_id, collection in doc_collection_pairs
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id, collection = future.result()
                # doc_id is the page number!
                scores.append((score, doc_id, collection))

        # Sort the reranked results by score in descending order.
        scores.sort(key=lambda x: x[0], reverse=True)

        # Unload the collections after the search to free memory.
        for collection in collections:
            self.client.release_collection(collection_name=collection)

        return scores[:topk] if len(scores) >= topk else scores

    def insert(self, data):
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length
        seq_ids = list(range(seq_length))
        # Store the filepath only on the first row of each document.
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors):
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)
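

# Minimal usage sketch (illustrative only, not part of the original module).
# The URI, collection name, "topk" environment value, and the random
# placeholder embeddings below are assumptions: in practice the multi-vector
# (ColBERT-style) page embeddings would come from a vision-language retriever.
if __name__ == "__main__":
    # search() reads the per-collection candidate limit from the environment.
    os.environ.setdefault("topk", "50")

    manager = MilvusManager(
        milvus_uri="milvus_demo.db",  # hypothetical local Milvus Lite file
        collection_name="pages",      # hypothetical collection name
        create_collection=True,
        dim=128,
    )

    # Random stand-ins for per-page multi-vector embeddings:
    # one (num_tokens, dim) array per page image.
    pages = [
        {"colbert_vecs": np.random.rand(16, 128).tolist(), "filepath": "page_0.png"},
        {"colbert_vecs": np.random.rand(16, 128).tolist(), "filepath": "page_1.png"},
    ]
    manager.insert_images_data(pages)

    # Query with a (num_query_tokens, dim) array; search() returns
    # (score, doc_id, collection_name) tuples sorted by descending score.
    query_vecs = np.random.rand(8, 128)
    print(manager.search(query_vecs, topk=1))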