Draken007's picture
Upload 7228 files
2a0bc63 verified
"""
Compatibility layer for legacy VectorTable (used by LangChain integration
(as of August 2023).
Note: This is to be replaced by direct usage of the table-class-hierarchy classes.
"""
from typing import Any, Dict, List, Optional
from warnings import warn
from cassandra.cluster import ResponseFuture
from cassio.table.table_types import RowType
from cassio.table.tables import MetadataVectorCassandraTable
new_columns_to_legacy = {
"row_id": "document_id",
"body_blob": "document",
"vector": "embedding_vector",
}
legacy_columns_to_new = {v: k for k, v in new_columns_to_legacy.items()}
class VectorTable:
"""
This class is a rewriting of the VectorTable created for use in LangChain
integration, this time relying on the class-table-hierarchy (cassio.table.*).
It mostly provides a translation layer between parameters and key names,
using a metadata+vector table class internally.
Additional kwargs, for use in this new table class, are passed as they are
in order to enable their usage already before adapting the LangChain
integration code.
"""
DEPRECATION_MESSAGE = (
"Class `VectorTable` is a legacy construct and "
"will be deprecated in future versions of CassIO."
)
def __init__(self, *pargs: Any, **kwargs: Any):
#
warn(self.DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
#
if "embedding_dimension" in kwargs:
vector_dimension = kwargs["embedding_dimension"]
new_kwargs = {
**{
k: v
for k, v in kwargs.items()
if k != "embedding_dimension"
# let's get rid of the infamous 'auto_id' here:
if k != "auto_id"
},
**{"vector_dimension": vector_dimension},
}
else:
new_kwargs = kwargs
# this legacy VectorTable will have everything indexed for search:
md_kwargs = {
**{"metadata_indexing": "all"},
**new_kwargs,
}
#
self.table = MetadataVectorCassandraTable(*pargs, **md_kwargs)
def search(
self,
embedding_vector: List[float],
top_k: int,
metric: str = "cos",
metric_threshold: Optional[float] = None,
**kwargs: Any,
) -> List[RowType]:
# get rows by ANN
enriched_hits = self.table.metric_ann_search(
vector=embedding_vector,
n=top_k,
metric=metric,
metric_threshold=metric_threshold,
**kwargs,
)
#
return [self._make_dict_legacy(rich_hit) for rich_hit in enriched_hits]
def put(
self,
document: str,
embedding_vector: List[float],
document_id: Any,
metadata: Optional[Dict[str, Any]] = None,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> None:
self.table.put(
row_id=document_id,
body_blob=document,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
**kwargs,
)
def put_async(
self,
document: str,
embedding_vector: List[float],
document_id: Any,
metadata: Optional[Dict[str, Any]] = None,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> ResponseFuture:
return self.table.put_async(
row_id=document_id,
body_blob=document,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
**kwargs,
)
def get(self, document_id: Any, **kwargs: Any) -> Optional[RowType]:
row_or_none = self.table.get(row_id=document_id, **kwargs)
if row_or_none:
return self._make_dict_legacy(row_or_none)
else:
return row_or_none
def delete(self, document_id: Any, **kwargs: Any) -> None:
self.table.delete(row_id=document_id, **kwargs)
return None
def clear(self) -> None:
self.table.clear()
return None
@staticmethod
def _make_dict_legacy(new_dict: RowType) -> RowType:
return {new_columns_to_legacy.get(k, k): v for k, v in new_dict.items()}