# demo/milvus_manager.py
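"""Wrapper around pymilvus.MilvusClient for storing and searching ColBERT-style
multi-vector page embeddings.

Each page is inserted as one row per token vector (all rows share a doc_id).
search() first runs an ANN search over every collection to collect candidate
doc_ids, then reranks the candidates with a MaxSim-style dot-product score.
"""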
import os
import concurrent.futures

import dotenv
import numpy as np
from pymilvus import MilvusClient, DataType
class MilvusManager:
    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        # Load environment variables from the .env file (e.g. the "topk" limit used in search()).
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        # Reuse an existing collection if present; otherwise optionally create and index it.
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()
    def create_collection(self):
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )
    def create_index(self):
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name="vector_index",
            params={"nlist": 128},
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )
    def search(self, data, topk):
        # Retrieve all collection names from the Milvus client.
        collections = self.client.list_collections()

        # Search parameters; the metric must match the index (COSINE, see create_index).
        search_params = {"metric_type": "COSINE", "params": {}}

        # Set to store unique (doc_id, collection_name) pairs across all collections.
        doc_collection_pairs = set()

        # Query each collection individually.
        for collection in collections:
            self.client.load_collection(collection_name=collection)
            print("collection loaded: " + collection)
            results = self.client.search(
                collection,
                data,
                # Per-collection candidate limit, taken from the "topk" entry in .env
                # and falling back to 50 if it is not set.
                limit=int(os.environ.get("topk", 50)),
                output_fields=["vector", "seq_id", "doc_id"],
                search_params=search_params,
            )
            # Accumulate document IDs along with their originating collection.
            for r_id in range(len(results)):
                for r in range(len(results[r_id])):
                    doc_id = results[r_id][r]["entity"]["doc_id"]
                    doc_collection_pairs.add((doc_id, collection))

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Fetch the per-token vectors of the candidate document from its collection.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=16380,
            )
            # Stack the vectors for the dot-product computation.
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            # MaxSim-style score: for each query vector take the best-matching
            # document vector, then sum over the query vectors.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id, collection_name)

        # Use a thread pool to rerank the candidate documents concurrently.
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, collection
                ): (doc_id, collection)
                for doc_id, collection in doc_collection_pairs
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id, collection = future.result()
                # doc_id is the page number.
                scores.append((score, doc_id, collection))

        # Sort the reranked results by score in descending order.
        scores.sort(key=lambda x: x[0], reverse=True)

        # Unload the collections after the search to free memory.
        for collection in collections:
            self.client.release_collection(collection_name=collection)

        return scores[:topk] if len(scores) >= topk else scores
"""
search_params = {"metric_type": "IP", "params": {}}
results = self.client.search(
self.collection_name,
data,
limit=50,
output_fields=["vector", "seq_id", "doc_id"],
search_params=search_params,
)
doc_ids = {result["entity"]["doc_id"] for result in results[0]}
scores = []
def rerank_single_doc(doc_id, data, client, collection_name):
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
output_fields=["seq_id", "vector", "doc"],
limit=1000,
)
doc_vecs = np.vstack(
[doc["vector"] for doc in doc_colbert_vecs]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return score, doc_id
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, self.client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk]
"""
    def insert(self, data):
        # One row per ColBERT token vector; all rows of a page share the same doc_id.
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length
        seq_ids = list(range(seq_length))
        # Store the source file path only once, on the first row of the page.
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )
    def get_images_as_doc(self, images_with_vectors):
        # Assign sequential doc_ids (page numbers) to the embedded images.
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)
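

# Illustrative usage sketch (not part of the original module). It assumes a
# local Milvus instance at the URI below; the collection name and the random
# arrays are placeholders standing in for real ColBERT-style embeddings
# produced elsewhere.
if __name__ == "__main__":
    manager = MilvusManager(
        milvus_uri="http://localhost:19530",  # assumed local Milvus endpoint
        collection_name="pages",  # hypothetical collection name
        create_collection=True,
        dim=128,
    )

    # Pretend embeddings: two "pages", each with 32 token vectors of dim 128.
    fake_pages = [
        {"colbert_vecs": np.random.rand(32, 128).tolist(), "filepath": f"page_{i}.png"}
        for i in range(2)
    ]
    manager.insert_images_data(fake_pages)

    # A query is likewise a (num_query_tokens, dim) array of token embeddings.
    query_vecs = np.random.rand(16, 128)
    for score, doc_id, collection in manager.search(query_vecs, topk=3):
        print(collection, doc_id, score)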