# Import necessary modules
from pymilvus import MilvusClient, DataType  # Milvus client and data type definitions
import numpy as np  # For numerical operations
import concurrent.futures  # For concurrent execution of tasks


class MilvusManager:
    """
    A manager class for interacting with the Milvus database, handling collection
    creation, data insertion, and search functionality.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """
        Initialize the MilvusManager.

        Args:
            milvus_uri (str): URI for connecting to the Milvus server.
            collection_name (str): Name of the collection in Milvus.
            create_collection (bool): Whether to create a new collection.
            dim (int): Dimensionality of the vector embeddings (default is 128).
        """
        self.client = MilvusClient(uri=milvus_uri)  # Initialize the Milvus client
        self.collection_name = collection_name
        self.dim = dim

        # Load the collection if it already exists
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        # Optionally (re)create the collection and its index
        if create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """
        Create a new collection in Milvus with a predefined schema.
        """
        # Drop the collection if it already exists
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

        # Define the schema for the collection
        schema = self.client.create_schema(
            auto_id=True,  # Enable automatic ID assignment
            enable_dynamic_field=True,  # Allow dynamic fields
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)  # Primary key
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim  # Vector field
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)  # Sequence ID
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)  # Document ID
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)  # Document path

        # Create the collection with the specified schema
        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

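    # Layout note: each document is stored as multiple rows, one per token
    # vector (a ColBERT/ColPali-style multi-vector layout); "doc_id" ties a
    # document's rows together and "seq_id" orders them.
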
    def create_index(self):
        """
        Create an HNSW index for the vector field in the collection.
        """
        # Release the collection before rebuilding the index
        self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector_index"
        )

        # Define the HNSW index parameters
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",  # Hierarchical Navigable Small World graph index
            metric_type="IP",  # Inner Product (dot product) as similarity metric
            params={
                "M": 16,  # Maximum number of graph edges per node
                "efConstruction": 500,  # Size of the candidate list during graph construction
            },
        )

        # Create the index and wait for the server to finish building it
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

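    # Note: larger M and efConstruction values generally improve recall at the
    # cost of slower index builds and higher memory use; the values above may
    # need tuning for other workloads.
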
    def create_scalar_index(self):
        """
        Create an inverted index for scalar fields such as document IDs.
        """
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",  # Inverted index for scalar data
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

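    # This inverted index accelerates the "doc_id" filter used when reranking
    # documents inside search().
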
    def search(self, data, topk, threshold=0.7):
        """
        Search for the top-k most similar vectors in the collection, filtered by a relevance threshold.

        Args:
            data (array-like): Query vectors, one row per query token.
            topk (int): Number of top results to return.
            threshold (float): Minimum score threshold for relevance (default is 0.7).

        Returns:
            list: Sorted list of (score, doc_id) tuples that meet the threshold.
        """
        search_params = {"metric_type": "IP", "params": {}}  # Search parameters for Inner Product
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,  # Initial retrieval limit per query vector
            output_fields=["vector", "seq_id", "doc_id"],  # Fields to include in the output
            search_params=search_params,
        )

        # Collect unique document IDs from the search results
        doc_ids = set()
        for r_id in range(len(results)):
            for r in range(len(results[r_id])):
                doc_ids.add(results[r_id][r]["entity"]["doc_id"])

        scores = []

        # Rerank a single document by its MaxSim (late-interaction) score against the query
        def rerank_single_doc(doc_id, data, client, collection_name):
            # Fetch all token vectors belonging to this document
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id == {doc_id}",  # Restrict to this document's vectors
                output_fields=["seq_id", "vector", "doc"],  # Fields to retrieve
                limit=1000,  # Retrieve a maximum of 1000 vectors per document
            )
            # MaxSim: for each query vector, take the best-matching document
            # vector, then sum these maxima across query vectors
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        # Use multithreading to rerank candidate documents in parallel
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id = future.result()
                scores.append((score, doc_id))

        # Keep only results that meet the threshold
        filtered_scores = [item for item in scores if item[0] >= threshold]
        # Sort scores in descending order and return the top-k results
        filtered_scores.sort(key=lambda x: x[0], reverse=True)
        return filtered_scores[:topk]

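    # Example (hypothetical usage): `data` is the stack of per-token query
    # embeddings produced upstream by a ColBERT/ColPali-style model, e.g. a
    # float32 array of shape (num_query_tokens, 128):
    #
    #   query_vecs = np.random.rand(20, 128).astype(np.float32)  # stand-in embeddings
    #   hits = manager.search(query_vecs, topk=5)  # -> [(score, doc_id), ...]
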
    def insert(self, data):
        """
        Insert a batch of data into the collection.

        Args:
            data (dict): Dictionary containing vector embeddings and metadata.
        """
        colbert_vecs = list(data["colbert_vecs"])
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length  # Same document ID for every token vector
        seq_ids = list(range(seq_length))
        docs = [""] * seq_length
        docs[0] = data["filepath"]  # Store the file path only in the first entry

        # Insert one row per token vector
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

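    # Expected input (inferred from the fields above): for a document with N
    # token vectors of dimension `dim`,
    #   data = {
    #       "colbert_vecs": array-like of shape (N, dim),
    #       "doc_id": int,
    #       "filepath": str,
    #   }
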
    def get_images_as_doc(self, images_with_vectors: list):
        """
        Convert image data with vectors into a document-like format for insertion.

        Args:
            images_with_vectors (list): List of dictionaries containing image vectors and file paths.

        Returns:
            list: Transformed data ready for insertion.
        """
        images_data = []
        for i in range(len(images_with_vectors)):
            data = {
                "colbert_vecs": images_with_vectors[i]["colbert_vecs"],
                "doc_id": i,  # Use the list position as the document ID
                "filepath": images_with_vectors[i]["filepath"],
            }
            images_data.append(data)
        return images_data

    def insert_images_data(self, image_data):
        """
        Insert processed image data into the collection.

        Args:
            image_data (list): List of image data dictionaries.
        """
        data = self.get_images_as_doc(image_data)
        for i in range(len(data)):
            self.insert(data[i])  # Insert each document individually
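

# Minimal usage sketch (assumptions: Milvus Lite backed by a local
# "milvus_demo.db" file, random arrays standing in for real ColBERT/ColPali
# embeddings; the URI, collection name, and shapes below are illustrative,
# not prescribed by this class):
if __name__ == "__main__":
    manager = MilvusManager(
        milvus_uri="milvus_demo.db",  # hypothetical local Milvus Lite database file
        collection_name="demo_pages",  # hypothetical collection name
        create_collection=True,
    )

    # Index three fake "pages", each with 100 token vectors of dimension 128
    fake_pages = [
        {
            "colbert_vecs": np.random.rand(100, 128).astype(np.float32),
            "filepath": f"page_{i}.png",  # hypothetical file paths
        }
        for i in range(3)
    ]
    manager.insert_images_data(fake_pages)

    # Query with 20 fake query-token vectors; prints [(score, doc_id), ...]
    query_vecs = np.random.rand(20, 128).astype(np.float32)
    print(manager.search(query_vecs, topk=2, threshold=0.0))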