|
import base64 |
|
import pickle |
|
from typing import Any, Iterable, List, Optional, Tuple |
|
|
|
from omagent_core.memories.ltms.ltm_base import LTMBase |
|
from omagent_core.services.connectors.milvus import MilvusConnector |
|
from omagent_core.utils.registry import registry |
|
from pydantic import Field |
|
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema, |
|
utility) |
|
|
|
|
|
@registry.register_component()
class VideoMilvusLTM(LTMBase):
    """Dict-like long-term memory stored in a Milvus collection.

    Every entry is one row with three fields: ``key`` (VARCHAR primary key),
    ``value`` (JSON) and ``embedding`` (FLOAT_VECTOR with ``dim`` dimensions).
    The collection named ``storage_name`` is created lazily on the first
    write. Keys are always converted with ``str()`` before being used.
    """

    # Connector whose ``_client`` attribute is used for every Milvus call.
    milvus_ltm_client: MilvusConnector
    # Name of the Milvus collection that backs this memory.
    storage_name: str = Field(default="default")
    # Dimensionality of the stored embedding vectors.
    dim: int = Field(default=128)

    def model_post_init(self, __context: Any) -> None:
        """Do nothing; the collection is created on the first write."""
        pass

    def _key_filter(self, key: Any) -> str:
        """Build the filter expression that matches exactly one key.

        Backslashes and double quotes are escaped so that a key containing
        them cannot terminate the string literal early or change the meaning
        of the expression.
        """
        escaped = str(key).replace("\\", "\\\\").replace('"', '\\"')
        return f'key == "{escaped}"'

    def _create_collection(self) -> None:
        """Create the backing collection and its vector index if missing."""
        client = self.milvus_ltm_client._client
        if client.has_collection(self.storage_name):
            return

        key_field = FieldSchema(
            name="key", dtype=DataType.VARCHAR, is_primary=True, max_length=256
        )
        value_field = FieldSchema(
            name="value", dtype=DataType.JSON, description="Json value"
        )
        embedding_field = FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            description="Embedding vector",
            dim=self.dim,
        )
        schema = CollectionSchema(
            fields=[key_field, value_field, embedding_field],
            description="Key-Value storage with embeddings",
        )

        # One index per vector field, named after the field itself.
        index_params = client.prepare_index_params()
        for field in schema.fields:
            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
                index_params.add_index(
                    field_name=field.name,
                    index_name=field.name,
                    index_type="FLAT",
                    metric_type="COSINE",
                    params={"nlist": 128},
                )

        client.create_collection(
            self.storage_name, schema=schema, index_params=index_params
        )
        print(f"Created storage {self.storage_name} successfully")

    def __getitem__(self, key: Any) -> Any:
        """Return the stored value for ``key``.

        Raises:
            KeyError: if no row with this key exists.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=self._key_filter(key),
            output_fields=["value"],
        )
        if not res:
            raise KeyError(f"Key {key} not found")
        return res[0]["value"]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Args:
            key: entry key; converted with ``str()``.
            value: a dict with a ``"value"`` item (the payload written to the
                JSON field) and an ``"embedding"`` item (the vector).

        Raises:
            ValueError: if ``value`` is not such a dict or the embedding is
                ``None``.
        """
        self._create_collection()

        if not (
            isinstance(value, dict) and "value" in value and "embedding" in value
        ):
            raise ValueError(
                "When setting an item, value must be a dictionary containing 'value' and 'embedding' keys."
            )
        actual_value = value["value"]
        embedding = value["embedding"]
        if embedding is None:
            raise ValueError("An embedding vector must be provided.")

        key_str = str(key)
        client = self.milvus_ltm_client._client

        # Remove an existing row first so the insert replaces the entry
        # instead of adding a second row with the same key.
        if key_str in self:
            client.delete(self.storage_name, filter=self._key_filter(key_str))

        client.insert(
            collection_name=self.storage_name,
            data=[
                {
                    "key": key_str,
                    "value": actual_value,
                    "embedding": embedding,
                }
            ],
        )

    def __delitem__(self, key: Any) -> None:
        """Delete the entry stored under ``key``.

        Raises:
            KeyError: if no row with this key exists.
        """
        if key not in self:
            raise KeyError(f"Key {key} not found")
        # The expression must go in ``filter=``: the second positional
        # parameter of the client's ``delete`` is the primary-key list, so a
        # positional expression would be looked up as a literal key.
        self.milvus_ltm_client._client.delete(
            self.storage_name, filter=self._key_filter(key)
        )

    def __contains__(self, key: Any) -> bool:
        """Return True if a row with this key exists."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter=self._key_filter(key),
            output_fields=["key"],
        )
        return len(res) > 0

    def __len__(self) -> int:
        """Return the number of rows, counted by querying every key."""
        results = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter='key != ""',
            output_fields=["key"],
            consistency_level="Strong",
        )
        return len(results)

    def keys(self, limit: int = 10) -> Iterable[Any]:
        """Return an iterator over at most ``limit`` stored keys."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, filter="", output_fields=["key"], limit=limit
        )
        return (item["key"] for item in res)

    def values(self) -> Iterable[Any]:
        """Yield every stored value.

        Values are yielded exactly as stored in the JSON field, matching
        ``__getitem__`` and ``items``; ``__setitem__`` writes the payload
        unencoded, so there is nothing to decode here.
        """
        res = self.milvus_ltm_client._client.query(
            self.storage_name,
            filter='key != ""',
            output_fields=["value"],
            consistency_level="Strong",
        )
        for item in res:
            yield item["value"]

    def items(self) -> Iterable[Tuple[Any, Any]]:
        """Yield ``(key, value)`` for every stored row."""
        res = self.milvus_ltm_client._client.query(
            self.storage_name, filter='key != ""', output_fields=["key", "value"]
        )
        for item in res:
            yield (item["key"], item["value"])

    def get(self, key: Any, default: Any = None) -> Any:
        """Return the value for ``key``, or ``default`` if it is missing."""
        try:
            return self[key]
        except KeyError:
            return default

    def clear(self) -> None:
        """Delete every row whose key is not the empty string."""
        self.milvus_ltm_client._client.delete(
            self.storage_name, filter='key != ""'
        )

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove ``key`` and return its value.

        If the key is missing, ``default`` is returned when it is not
        ``None``. A ``None`` default is treated the same as no default, so
        ``KeyError`` propagates in that case.
        """
        try:
            value = self[key]
            self.__delitem__(key)
            return value
        except KeyError:
            if default is not None:
                return default
            raise

    def update(self, other: Iterable[Tuple[Any, Any]]) -> None:
        """Store every ``(key, value)`` pair from ``other`` via item assignment."""
        for key, value in other:
            self[key] = value

    def get_by_vector(
        self,
        embedding: List[float],
        top_k: int = 10,
        threshold: float = 0.0,
        filter: str = "",
    ) -> List[Tuple[Any, Any]]:
        """Return the entries closest to ``embedding`` by cosine similarity.

        Args:
            embedding: query vector.
            top_k: maximum number of matches to return.
            threshold: passed to the search as ``radius``; ``range_filter``
                is fixed at 1.
            filter: optional scalar filter expression forwarded to the search.

        Returns:
            A list of ``(key, value)`` pairs in the order the search returned
            them; empty when the search yields no result set.
        """
        search_params = {
            "metric_type": "COSINE",
            "params": {"nprobe": 10, "range_filter": 1, "radius": threshold},
        }
        results = self.milvus_ltm_client._client.search(
            self.storage_name,
            data=[embedding],
            anns_field="embedding",
            search_params=search_params,
            limit=top_k,
            output_fields=["key", "value"],
            consistency_level="Strong",
            filter=filter,
        )
        if not results:
            return []

        # One query vector was sent, so only the first result set is relevant.
        items = []
        for match in results[0]:
            entity = match.get("entity")
            items.append((entity.get("key"), entity.get("value")))
        return items
|
|