# multimodal_rag/milvus_manager.py
# Import necessary modules
from pymilvus import MilvusClient, DataType # Milvus client and data type definitions
import numpy as np # For numerical operations
import concurrent.futures  # For concurrent execution of tasks


class MilvusManager:
"""
A manager class for interacting with the Milvus database, handling collection creation,
data insertion, and search functionality.
"""
def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
"""
Initialize the MilvusManager.
Args:
milvus_uri (str): URI for connecting to the Milvus server.
collection_name (str): Name of the collection in Milvus.
create_collection (bool): Whether to create a new collection.
dim (int): Dimensionality of the vector embeddings (default is 128).
"""
self.client = MilvusClient(uri=milvus_uri) # Initialize the Milvus client
self.collection_name = collection_name
self.dim = dim
        # Load the collection if it already exists
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
        # Optionally (re)create the collection and its index
        if create_collection:
            self.create_collection()  # Create a new collection (drops any existing one)
            self.create_index()  # Create an HNSW index on the vector field
def create_collection(self):
"""
Create a new collection in Milvus with a predefined schema.
"""
# Drop the collection if it already exists
if self.client.has_collection(collection_name=self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
# Define the schema for the collection
schema = self.client.create_schema(
            auto_id=True,  # Enable automatic primary-key assignment
            enable_dynamic_field=True,  # Allow fields not declared in the schema
)
schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True) # Primary key
schema.add_field(
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim # Vector field
)
schema.add_field(field_name="seq_id", datatype=DataType.INT16) # Sequence ID
schema.add_field(field_name="doc_id", datatype=DataType.INT64) # Document ID
schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535) # Document path
# Create the collection with the specified schema
self.client.create_collection(
collection_name=self.collection_name, schema=schema
)
def create_index(self):
"""
Create an HNSW index for the vector field in the collection.
"""
        # Release the collection and drop any existing index before rebuilding it
        self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector_index"
        )
# Define the HNSW index parameters
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_name="vector_index",
index_type="HNSW", # Hierarchical Navigable Small World graph index
metric_type="IP", # Inner Product (dot product) as similarity metric
            params={
                "M": 16,  # Maximum number of graph edges per node
                "efConstruction": 500,  # Candidate list size during index construction
            },
)
# Create the index and synchronize with the server
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def create_scalar_index(self):
"""
Create an inverted index for scalar fields such as document IDs.
"""
self.client.release_collection(collection_name=self.collection_name)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="doc_id",
index_name="int32_index",
index_type="INVERTED", # Inverted index for scalar data
)
self.client.create_index(
collection_name=self.collection_name, index_params=index_params, sync=True
)
def search(self, data, topk, threshold=0.7):
"""
Search for the top-k most similar vectors in the collection, filtered by a relevance threshold.
Args:
data (array-like): Query vector.
topk (int): Number of top results to return.
            threshold (float): Minimum score threshold for relevance (default is 0.7).
Returns:
list: Sorted list of top-k results that meet the threshold.
"""
search_params = {"metric_type": "IP", "params": {}} # Search parameters for Inner Product
results = self.client.search(
self.collection_name,
data,
limit=50, # Initial retrieval limit
output_fields=["vector", "seq_id", "doc_id"], # Fields to include in the output
search_params=search_params,
)
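        # Two-stage retrieval: the ANN search above collects candidate pages via their
        # individual vectors; the candidates are then re-scored below with an exact
        # ColBERT-style MaxSim over all of each document's vectors.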
        # Collect unique document IDs from the search results
        doc_ids = set()
        for hits in results:
            for hit in hits:
                doc_ids.add(hit["entity"]["doc_id"])
scores = []
# Function to rerank a single document based on its relevance to the query
def rerank_single_doc(doc_id, data, client, collection_name):
doc_colbert_vecs = client.query(
collection_name=collection_name,
filter=f"doc_id in [{doc_id}, {doc_id + 1}]", # Query documents by ID
output_fields=["seq_id", "vector", "doc"], # Fields to retrieve
limit=1000, # Retrieve a maximum of 1000 vectors per document
)
            # ColBERT-style late interaction (MaxSim): for each query vector, take the
            # maximum inner product over the document's vectors, then sum over query vectors
doc_vecs = np.vstack(
[doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
)
score = np.dot(data, doc_vecs.T).max(1).sum()
return (score, doc_id)
# Use multithreading to rerank documents in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
futures = {
executor.submit(
rerank_single_doc, doc_id, data, self.client, self.collection_name
): doc_id
for doc_id in doc_ids
}
for future in concurrent.futures.as_completed(futures):
score, doc_id = future.result()
scores.append((score, doc_id))
# Filter scores by threshold
filtered_scores = [item for item in scores if item[0] >= threshold]
# Sort scores in descending order and return the top-k results
filtered_scores.sort(key=lambda x: x[0], reverse=True)
        return filtered_scores[:topk]
def insert(self, data):
"""
Insert a batch of data into the collection.
Args:
data (dict): Dictionary containing vector embeddings and metadata.
"""
        colbert_vecs = list(data["colbert_vecs"])
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length
seq_ids = list(range(seq_length))
docs = [""] * seq_length
docs[0] = data["filepath"] # Store file path in the first entry
# Insert the data into the collection
self.client.insert(
self.collection_name,
[
{
"vector": colbert_vecs[i],
"seq_id": seq_ids[i],
"doc_id": doc_ids[i],
"doc": docs[i],
}
for i in range(seq_length)
],
)
def get_images_as_doc(self, images_with_vectors: list):
"""
Convert image data with vectors into document-like format for insertion.
Args:
images_with_vectors (list): List of dictionaries containing image vectors and file paths.
Returns:
list: Transformed data ready for insertion.
"""
        images_data = []
        for doc_id, image in enumerate(images_with_vectors):
            images_data.append(
                {
                    "colbert_vecs": image["colbert_vecs"],
                    "doc_id": doc_id,
                    "filepath": image["filepath"],
                }
            )
        return images_data
def insert_images_data(self, image_data):
"""
Insert processed image data into the collection.
Args:
image_data (list): List of image data dictionaries.
"""
data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)  # Insert each page's vectors individually
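# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a local Milvus Lite database file and uses random vectors in place
# of real ColBERT/ColPali page embeddings; the file name "milvus_demo.db" and
# the collection name "demo_pages" are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = MilvusManager(
        milvus_uri="./milvus_demo.db",  # hypothetical Milvus Lite file
        collection_name="demo_pages",   # hypothetical collection name
        create_collection=True,
        dim=128,
    )

    # Two fake "pages", each represented by 10 random vectors of dimension 128
    fake_pages = [
        {"colbert_vecs": np.random.rand(10, 128).tolist(), "filepath": f"page_{i}.png"}
        for i in range(2)
    ]
    manager.insert_images_data(fake_pages)

    # Ensure the collection is loaded before searching (needed for custom schemas)
    manager.client.load_collection(collection_name="demo_pages")

    # A fake query made of 5 token-level vectors; search returns (score, doc_id) tuples
    query_vecs = np.random.rand(5, 128)
    print(manager.search(query_vecs, topk=3))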