Spaces:
Sleeping
Sleeping
from typing import Any | |
from uuid import uuid4 | |
import numpy as np | |
from omagent_core.utils.error import VQLError | |
from omagent_core.utils.registry import registry | |
from pydantic import BaseModel | |
from pymilvus import Collection, DataType, MilvusClient, connections, utility | |
from pymilvus.client import types | |
class MilvusHandler(BaseModel): | |
host_url: str = "./memory.db" | |
user: str = "" | |
password: str = "" | |
db_name: str = "default" | |
primary_field: Any = None | |
vector_field: Any = None | |
class Config: | |
"""Configuration for this pydantic object.""" | |
extra = "allow" | |
arbitrary_types_allowed = True | |
def __init__(self, **data: Any): | |
super().__init__(**data) | |
self.milvus_client = MilvusClient( | |
uri=self.host_url, | |
user=self.user, | |
password=self.password, | |
db_name=self.db_name, | |
) | |
def is_collection_in(self, collection_name): | |
""" | |
Check if a collection exists in Milvus. | |
Args: | |
collection_name (str): The name of the collection to check. | |
Returns: | |
bool: True if the collection exists, False otherwise. | |
""" | |
return self.milvus_client.has_collection(collection_name) | |
def make_collection(self, collection_name, schema): | |
""" | |
Create a new collection in Milvus. | |
This method will first check if a collection with the given name already exists. | |
If it does, it will print a message and do nothing. | |
If it doesn't, it will create a new collection with the given name and schema, | |
and then create an index for the vector field in the collection. | |
Args: | |
collection_name (str): The name of the collection to create. | |
schema (CollectionSchema): The schema of the collection to create. | |
Raises: | |
VQLError: If the schema does not have exactly one primary key. | |
""" | |
index_params = self.milvus_client.prepare_index_params() | |
for field in schema.fields: | |
if ( | |
field.dtype == DataType.FLOAT_VECTOR | |
or field.dtype == DataType.BINARY_VECTOR | |
): | |
index_params.add_index( | |
field_name=field.name, | |
index_name=field.name, | |
index_type="FLAT", | |
metric_type="COSINE", | |
params={"nlist": 128}, | |
) | |
print(f"{field.name} of {collection_name} index created") | |
if self.is_collection_in(collection_name): | |
print(f"{collection_name} collection already exists") | |
else: | |
self.milvus_client.create_collection( | |
collection_name, schema=schema, index_params=index_params | |
) | |
print(f"Create collection {collection_name} successfully") | |
def drop_collection(self, collection_name): | |
""" | |
Drop a collection in Milvus. | |
This method will first check if a collection with the given name exists. | |
If it does, it will drop the collection and print a success message. | |
If it doesn't, it will print a message indicating that the collection does not exist. | |
Args: | |
collection_name (str): The name of the collection to drop. | |
""" | |
if self.is_collection_in(collection_name): | |
self.milvus_client.drop_collection(collection_name) | |
print(f"Drop collection {collection_name} successfully") | |
else: | |
print(f"{collection_name} collection does not exist") | |
def do_add(self, collection_name, vectors): | |
""" | |
Add vectors to a collection in Milvus. | |
This method will first check if a collection with the given name exists. | |
If it does, it will add the vectors to the collection and return the IDs of the added vectors. | |
If it doesn't, it will raise a VQLError. | |
Args: | |
collection_name (str): The name of the collection to add vectors to. | |
vectors (list): The vectors to add to the collection. | |
Returns: | |
list: The IDs of the added vectors. | |
Raises: | |
VQLError: If the collection does not exist. | |
""" | |
if self.is_collection_in(collection_name): | |
res = self.milvus_client.insert( | |
collection_name=collection_name, data=vectors | |
) | |
return res["ids"] | |
else: | |
raise VQLError(500, detail=f"{collection_name} collection does not exist") | |
def match( | |
self, | |
collection_name, | |
query_vectors: list, | |
query_field, | |
output_fields: list = None, | |
res_size=10, | |
filter_expr="", | |
threshold=0, | |
): | |
""" | |
Perform a vector similarity search in a specified collection in Milvus. | |
This method will first check if a collection with the given name exists. | |
If it does, it will perform a vector similarity search using the provided query vectors, | |
and return the search results. | |
If it doesn't, it will raise a VQLError. | |
Args: | |
collection_name (str): The name of the collection to search in. | |
query_vectors (list): The vectors to use as query for the search. | |
query_field (str): The field to perform the search on. | |
output_fields (list): The fields to include in the search results. | |
res_size (int): The maximum number of search results to return. | |
filter_expr (str): The filter expression to apply during the search. | |
threshold (float): The threshold for the similarity search. | |
Returns: | |
list: The search results. | |
Raises: | |
VQLError: If the collection does not exist. | |
""" | |
if self.is_collection_in(collection_name): | |
search_params = { | |
"metric_type": "COSINE", | |
"ignore_growing": False, | |
"params": { | |
"nprobe": 10, | |
"radius": 2 * threshold - 1, | |
"range_filter": 1, | |
}, | |
} | |
hits = self.milvus_client.search( | |
collection_name=collection_name, | |
data=query_vectors, | |
anns_field=query_field, | |
search_params=search_params, | |
limit=res_size, | |
output_fields=output_fields, | |
filter=filter_expr, | |
) | |
return hits | |
else: | |
raise VQLError(500, detail=f"{collection_name} collection does not exist") | |
def delete_doc_by_ids(self, collection_name, ids): | |
""" | |
Delete specific documents in a collection in Milvus by their IDs. | |
This method will first check if a collection with the given name exists. | |
If it does, it will delete the documents with the provided IDs from the collection. | |
If it doesn't, it will raise a VQLError. | |
Args: | |
collection_name (str): The name of the collection to delete documents from. | |
ids (list): The IDs of the documents to delete. | |
Raises: | |
VQLError: If the collection does not exist. | |
""" | |
if self.is_collection_in(collection_name): | |
delete_expr = f"{self.primary_field} in {ids}" | |
res = self.milvus_client.delete( | |
collection_name=collection_name, filter=delete_expr | |
) | |
return res | |
else: | |
raise VQLError(500, detail=f"{collection_name} collection does not exist") | |
def delete_doc_by_expr(self, collection_name, expr): | |
""" | |
Delete specific documents in a collection in Milvus by an expression. | |
This method will first check if a collection with the given name exists. | |
If it does, it will delete the documents that match the provided expression from the collection. | |
If it doesn't, it will raise a VQLError. | |
Args: | |
collection_name (str): The name of the collection to delete documents from. | |
expr (str): The expression to match the documents to delete. | |
Raises: | |
VQLError: If the collection does not exist. | |
""" | |
if self.is_collection_in(collection_name): | |
self.milvus_client.delete(collection_name=collection_name, filter=expr) | |
else: | |
raise VQLError(500, detail=f"{collection_name} collection does not exist") | |
if __name__ == "__main__": | |
from pymilvus import CollectionSchema, DataType, FieldSchema | |
milvus_handler = MilvusHandler() | |
rng = np.random.default_rng() | |
pk = FieldSchema( | |
name="pk", | |
dtype=DataType.VARCHAR, | |
is_primary=True, | |
auto_id=False, | |
max_length=100, | |
) | |
bot_id = FieldSchema(name="bot_id", dtype=DataType.VARCHAR, max_length=50) | |
vector = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=512) | |
schema = CollectionSchema( | |
fields=[pk, bot_id, vector], | |
description="this is test", | |
) | |
data = [ | |
{ | |
"pk": str(uuid4()), | |
"bot_id": str(uuid4()), | |
# rng.random((1, 512)) | |
"vector": [1.0, 2.0] * 256, | |
} | |
] | |
milvus_handler.drop_collection("test1") | |
milvus_handler.make_collection("test1", schema) | |
add_detail = milvus_handler.do_add("test1", data) | |
print(add_detail) | |
print(milvus_handler.milvus_client.describe_index("test1", "vector")) | |
test_data = [[1.0, 2.0] * 256, [100, 400] * 256] | |
match_result = milvus_handler.match( | |
"test1", test_data, "vector", ["pk"], 10, "", 0.65 | |
) | |
print(match_result) | |
# milvus_handler.primary_field = "pk" | |
# milvus_handler.delete_doc_by_ids("test1", ["1f764837-b80b-4788-ad8c-7a89924e343b"]) | |