Draken007's picture
Upload 7228 files
2a0bc63 verified
import inspect
from operator import itemgetter
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union
from cassandra.cluster import ResponseFuture
from cassio.table.base_table import BaseTable
from cassio.table.cql import SELECT_ANN_CQL_TEMPLATE, CQLOpType
from cassio.table.table_types import ColumnSpecType, RowType, RowWithDistanceType
from cassio.utils.vector.distance_metrics import distance_metrics
from .base_table import BaseTableMixin
class VectorMixin(BaseTableMixin):
def __init__(
self,
*pargs: Any,
vector_dimension: Union[int, Awaitable[int]],
vector_similarity_function: Optional[str] = None,
vector_source_model: Optional[str] = None,
**kwargs: Any,
) -> None:
if inspect.isawaitable(vector_dimension) and not kwargs.get(
"async_setup", False
):
raise ValueError(
"Cannot use an awaitable embedding_dimension "
"with async_setup set to False"
)
self.vector_dimension = vector_dimension
self.vector_index_options = []
if vector_similarity_function is not None:
self.vector_index_options.append(
("similarity_function", vector_similarity_function)
)
if vector_source_model is not None:
self.vector_index_options.append(("source_model", vector_source_model))
super().__init__(*pargs, **kwargs)
def _schema_da(self) -> List[ColumnSpecType]:
return super()._schema_da() + [
("vector", f"VECTOR<FLOAT,{self.vector_dimension}>")
]
async def _aschema_da(self) -> List[ColumnSpecType]:
if inspect.isawaitable(self.vector_dimension):
self.vector_dimension = await self.vector_dimension
return self._schema_da()
@staticmethod
def _get_create_vector_index_cql(
vector_index_options: List[Tuple[str, Any]]
) -> str:
index_name = "idx_vector"
index_column = "vector"
return BaseTable._get_create_index_cql(
index_name=index_name,
index_column=index_column,
index_options=vector_index_options,
)
def db_setup(self) -> None:
super().db_setup()
# index on the vector column:
create_index_cql = self._get_create_vector_index_cql(self.vector_index_options)
self.execute_cql(create_index_cql, op_type=CQLOpType.SCHEMA)
async def adb_setup(self) -> None:
await super().adb_setup()
# index on the vector column:
create_index_cql = self._get_create_vector_index_cql(self.vector_index_options)
await self.aexecute_cql(create_index_cql, op_type=CQLOpType.SCHEMA)
def _get_ann_search_cql(
self, vector: List[float], n: int, **kwargs: Any
) -> Tuple[str, Tuple[Any, ...]]:
n_kwargs = self._normalize_kwargs(kwargs)
# TODO: work on a columns: Optional[List[str]] = None
# (but with nuanced handling of the column-magic we have here)
columns = None
if columns is None:
columns_desc = "*"
else:
# TODO: handle translations here?
# columns_desc = ", ".join(columns)
raise NotImplementedError("Column selection is not implemented.")
#
if all(x == 0 for x in vector):
# TODO: lift/relax this constraint when non-cosine metrics are there.
raise ValueError("Cannot use identically-zero vectors in cos/ANN search.")
#
vector_column = "vector"
vector_cql_vals = (vector,)
#
(
rest_kwargs,
where_clause_blocks,
where_cql_vals,
) = self._extract_where_clause_blocks(n_kwargs)
(
rest_kwargs,
analyzer_clause_blocks,
analyzer_cql_vals,
) = self._extract_index_analyzers(rest_kwargs)
assert rest_kwargs == {}
all_where_clauses = where_clause_blocks + analyzer_clause_blocks
if not all_where_clauses:
where_clause = ""
else:
where_clause = "WHERE " + " AND ".join(all_where_clauses)
#
limit_clause = "LIMIT %s"
limit_cql_vals = (n,)
#
select_ann_cql = SELECT_ANN_CQL_TEMPLATE.format(
columns_desc=columns_desc,
vector_column=vector_column,
where_clause=where_clause,
limit_clause=limit_clause,
)
#
select_ann_cql_vals = (
where_cql_vals + analyzer_cql_vals + vector_cql_vals + limit_cql_vals
)
return select_ann_cql, select_ann_cql_vals
def ann_search(
self, vector: List[float], n: int, **kwargs: Any
) -> Iterable[RowType]:
select_ann_cql, select_ann_cql_vals = self._get_ann_search_cql(
vector, n, **kwargs
)
result_set = self.execute_cql(
select_ann_cql, args=select_ann_cql_vals, op_type=CQLOpType.READ
)
return (self._normalize_row(result) for result in result_set)
def ann_search_async(
self, vector: List[float], n: int, **kwargs: Any
) -> ResponseFuture:
raise NotImplementedError("Asynchronous reads are not supported.")
async def aann_search(
self, vector: List[float], n: int, **kwargs: Any
) -> Iterable[RowType]:
select_ann_cql, select_ann_cql_vals = self._get_ann_search_cql(
vector, n, **kwargs
)
result_set = await self.aexecute_cql(
select_ann_cql, args=select_ann_cql_vals, op_type=CQLOpType.READ
)
return (self._normalize_row(result) for result in result_set)
@staticmethod
def _get_rows_with_distance(
rows: Iterable[RowType],
vector: List[float],
metric: str,
metric_threshold: Optional[float] = None,
) -> Iterable[RowWithDistanceType]:
if rows == []:
return []
else:
# sort, cut, validate and prepare for returning
# evaluate metric
distance_function, distance_reversed = distance_metrics[metric]
row_vectors = [row["vector"] for row in rows]
# enrich with their metric score
rows_with_metric = list(
zip(
distance_function(row_vectors, vector),
rows,
)
)
# sort rows by metric score. First handle metric/threshold
if metric_threshold is not None:
_used_thr = metric_threshold
if distance_reversed:
def _thresholder(mtx: float, thr: float) -> bool:
return mtx >= thr
else:
def _thresholder(mtx: float, thr: float) -> bool:
return mtx <= thr
else:
# this to satisfy the type checker
_used_thr = 0.0
# no hits are discarded
def _thresholder(mtx: float, thr: float) -> bool:
return True
#
sorted_passing_rows = sorted(
(pair for pair in rows_with_metric if _thresholder(pair[0], _used_thr)),
key=itemgetter(0),
reverse=distance_reversed,
)
# return a list of hits with their distance (as JSON)
enriched_hits = (
{
**hit,
**{"distance": distance},
}
for distance, hit in sorted_passing_rows
)
return enriched_hits
def metric_ann_search(
self,
vector: List[float],
n: int,
metric: str,
metric_threshold: Optional[float] = None,
**kwargs: Any,
) -> Iterable[RowWithDistanceType]:
rows = list(self.ann_search(vector, n, **kwargs))
return self._get_rows_with_distance(rows, vector, metric, metric_threshold)
def metric_ann_search_async(
self, vector: List[float], n: int, **kwargs: Any
) -> ResponseFuture:
raise NotImplementedError("Asynchronous reads are not supported.")
async def ametric_ann_search(
self,
vector: List[float],
n: int,
metric: str,
metric_threshold: Optional[float] = None,
**kwargs: Any,
) -> Iterable[RowWithDistanceType]:
rows = list(await self.aann_search(vector, n, **kwargs))
return self._get_rows_with_distance(rows, vector, metric, metric_threshold)