Spaces:
Build error
Build error
use std::borrow::Cow; | |
use std::collections::HashSet; | |
use std::mem::take; | |
use api::rest::LookupLocation; | |
use collection::collection::distance_matrix::CollectionSearchMatrixRequest; | |
use collection::grouping::group_by::{GroupRequest, SourceRequest}; | |
use collection::lookup::WithLookup; | |
use collection::operations::payload_ops::{DeletePayloadOp, PayloadOps, SetPayloadOp}; | |
use collection::operations::point_ops::{PointIdsList, PointOperations}; | |
use collection::operations::types::{ | |
ContextExamplePair, CoreSearchRequest, CountRequestInternal, DiscoverRequestInternal, | |
PointRequestInternal, RecommendExample, RecommendRequestInternal, ScrollRequestInternal, | |
}; | |
use collection::operations::universal_query::collection_query::{ | |
CollectionPrefetch, CollectionQueryRequest, Query, VectorInputInternal, VectorQuery, | |
}; | |
use collection::operations::vector_ops::VectorOperations; | |
use collection::operations::CollectionUpdateOperations; | |
use segment::data_types::facets::FacetParams; | |
use segment::types::{Condition, ExtendedPointId, FieldCondition, Filter, Match, Payload}; | |
use super::{ | |
incompatible_with_payload_constraint, Access, AccessRequirements, CollectionAccessList, | |
CollectionAccessView, CollectionPass, PayloadConstraint, | |
}; | |
use crate::content_manager::collection_meta_ops::CollectionMetaOperations; | |
use crate::content_manager::errors::{StorageError, StorageResult}; | |
impl Access { | |
pub(crate) fn check_point_op<'a>( | |
&self, | |
collection_name: &'a str, | |
op: &mut impl CheckableCollectionOperation, | |
) -> Result<CollectionPass<'a>, StorageError> { | |
let requirements = op.access_requirements(); | |
match self { | |
Access::Global(mode) => mode.meets_requirements(requirements)?, | |
Access::Collection(list) => { | |
let view = list.find_view(collection_name)?; | |
view.meets_requirements(requirements)?; | |
op.check_access(view, list)?; | |
} | |
} | |
Ok(CollectionPass(Cow::Borrowed(collection_name))) | |
} | |
pub(crate) fn check_collection_meta_operation( | |
&self, | |
operation: &CollectionMetaOperations, | |
) -> Result<(), StorageError> { | |
match operation { | |
CollectionMetaOperations::CreateCollection(_) | |
| CollectionMetaOperations::UpdateCollection(_) | |
| CollectionMetaOperations::DeleteCollection(_) | |
| CollectionMetaOperations::ChangeAliases(_) | |
| CollectionMetaOperations::Resharding(_, _) | |
| CollectionMetaOperations::TransferShard(_, _) | |
| CollectionMetaOperations::SetShardReplicaState(_) | |
| CollectionMetaOperations::CreateShardKey(_) | |
| CollectionMetaOperations::DropShardKey(_) => { | |
self.check_global_access(AccessRequirements::new().manage())?; | |
} | |
CollectionMetaOperations::CreatePayloadIndex(op) => { | |
self.check_collection_access( | |
&op.collection_name, | |
AccessRequirements::new().write().whole(), | |
)?; | |
} | |
CollectionMetaOperations::DropPayloadIndex(op) => { | |
self.check_collection_access( | |
&op.collection_name, | |
AccessRequirements::new().write().whole(), | |
)?; | |
} | |
CollectionMetaOperations::Nop { token: _ } => (), | |
} | |
Ok(()) | |
} | |
} | |
trait CheckableCollectionOperation { | |
/// Used to distinguish whether the operation is read-only or read-write. | |
fn access_requirements(&self) -> AccessRequirements; | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError>; | |
} | |
impl CollectionAccessList { | |
fn check_lookup_from( | |
&self, | |
lookup_location: &Option<LookupLocation>, | |
) -> Result<(), StorageError> { | |
if let Some(lookup_location) = lookup_location { | |
self.find_view(&lookup_location.collection)? | |
.check_whole_access()?; | |
} | |
Ok(()) | |
} | |
fn check_with_lookup(&self, with_lookup: &Option<WithLookup>) -> Result<(), StorageError> { | |
if let Some(with_lookup) = with_lookup { | |
self.find_view(&with_lookup.collection_name)? | |
.check_whole_access()?; | |
} | |
Ok(()) | |
} | |
} | |
impl<'a> CollectionAccessView<'a> { | |
fn apply_filter(&self, filter: &mut Option<Filter>) { | |
if let Some(payload) = &self.payload { | |
let f = filter.get_or_insert_with(Default::default); | |
*f = take(f).merge_owned(payload.to_filter()); | |
} | |
} | |
fn check_recommend_example(&self, example: &RecommendExample) -> Result<(), StorageError> { | |
match example { | |
RecommendExample::PointId(_) => self.check_whole_access(), | |
RecommendExample::Dense(_) | RecommendExample::Sparse(_) => Ok(()), | |
} | |
} | |
fn check_vector_query( | |
&self, | |
vector_query: &VectorQuery<VectorInputInternal>, | |
) -> Result<(), StorageError> { | |
match vector_query { | |
VectorQuery::Nearest(nearest) => self.check_vector_input(nearest)?, | |
VectorQuery::RecommendBestScore(reco) | VectorQuery::RecommendAverageVector(reco) => { | |
for vector_input in reco.flat_iter() { | |
self.check_vector_input(vector_input)? | |
} | |
} | |
VectorQuery::Discover(discover) => { | |
for vector_input in discover.flat_iter() { | |
self.check_vector_input(vector_input)? | |
} | |
} | |
VectorQuery::Context(context) => { | |
for vector_input in context.flat_iter() { | |
self.check_vector_input(vector_input)? | |
} | |
} | |
}; | |
Ok(()) | |
} | |
fn check_vector_input(&self, vector_input: &VectorInputInternal) -> Result<(), StorageError> { | |
match vector_input { | |
VectorInputInternal::Vector(_) => Ok(()), | |
VectorInputInternal::Id(_) => self.check_whole_access(), | |
} | |
} | |
} | |
impl CheckableCollectionOperation for RecommendRequestInternal { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
for e in &self.positive { | |
view.check_recommend_example(e)?; | |
} | |
for e in &self.negative { | |
view.check_recommend_example(e)?; | |
} | |
access.check_lookup_from(&self.lookup_from)?; | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for PointRequestInternal { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: true, | |
} | |
} | |
fn check_access( | |
&mut self, | |
_view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for CoreSearchRequest { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for CountRequestInternal { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for GroupRequest { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
match &mut self.source { | |
SourceRequest::Search(s) => { | |
view.apply_filter(&mut s.filter); | |
} | |
SourceRequest::Recommend(r) => r.check_access(view, access)?, | |
SourceRequest::Query(q) => q.check_access(view, access)?, | |
} | |
access.check_with_lookup(&self.with_lookup)?; | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for DiscoverRequestInternal { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
if let Some(target) = &self.target { | |
view.check_recommend_example(target)?; | |
} | |
for ContextExamplePair { positive, negative } in self.context.iter().flat_map(|c| c.iter()) | |
{ | |
view.check_recommend_example(positive)?; | |
view.check_recommend_example(negative)?; | |
} | |
view.apply_filter(&mut self.filter); | |
access.check_lookup_from(&self.lookup_from)?; | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for ScrollRequestInternal { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for CollectionQueryRequest { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
view.apply_filter(&mut self.filter); | |
if let Some(Query::Vector(vector_query)) = &self.query { | |
view.check_vector_query(vector_query)? | |
} | |
access.check_lookup_from(&self.lookup_from)?; | |
for prefetch_query in self.prefetch.iter_mut() { | |
check_access_for_prefetch(prefetch_query, &view, access)?; | |
} | |
Ok(()) | |
} | |
} | |
fn check_access_for_prefetch( | |
prefetch: &mut CollectionPrefetch, | |
view: &CollectionAccessView<'_>, | |
access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
view.apply_filter(&mut prefetch.filter); | |
if let Some(Query::Vector(vector_query)) = &prefetch.query { | |
view.check_vector_query(vector_query)? | |
} | |
access.check_lookup_from(&prefetch.lookup_from)?; | |
// Recurse inner prefetches | |
for prefetch_query in prefetch.prefetch.iter_mut() { | |
check_access_for_prefetch(prefetch_query, view, access)?; | |
} | |
Ok(()) | |
} | |
impl CheckableCollectionOperation for FacetParams { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> StorageResult<()> { | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for CollectionSearchMatrixRequest { | |
fn access_requirements(&self) -> AccessRequirements { | |
AccessRequirements { | |
write: false, | |
manage: false, | |
whole: false, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> StorageResult<()> { | |
view.apply_filter(&mut self.filter); | |
Ok(()) | |
} | |
} | |
impl CheckableCollectionOperation for CollectionUpdateOperations { | |
fn access_requirements(&self) -> AccessRequirements { | |
match self { | |
CollectionUpdateOperations::PointOperation(_) | |
| CollectionUpdateOperations::VectorOperation(_) | |
| CollectionUpdateOperations::PayloadOperation(_) => AccessRequirements { | |
write: true, | |
manage: false, | |
whole: false, // Checked in `check_access()` | |
}, | |
CollectionUpdateOperations::FieldIndexOperation(_) => AccessRequirements { | |
write: true, | |
manage: true, | |
whole: true, | |
}, | |
} | |
} | |
fn check_access( | |
&mut self, | |
view: CollectionAccessView<'_>, | |
_access: &CollectionAccessList, | |
) -> Result<(), StorageError> { | |
match self { | |
CollectionUpdateOperations::PointOperation(op) => match op { | |
PointOperations::UpsertPoints(_) => { | |
view.check_whole_access()?; | |
} | |
PointOperations::DeletePoints { ids } => { | |
if let Some(payload) = &view.payload { | |
*op = PointOperations::DeletePointsByFilter( | |
make_filter_from_ids(take(ids)).merge_owned(payload.to_filter()), | |
); | |
} | |
} | |
PointOperations::DeletePointsByFilter(filter) => { | |
if let Some(payload) = &view.payload { | |
*filter = take(filter).merge_owned(payload.to_filter()); | |
} | |
} | |
PointOperations::SyncPoints(_) => { | |
view.check_whole_access()?; | |
} | |
}, | |
CollectionUpdateOperations::VectorOperation(op) => match op { | |
VectorOperations::UpdateVectors(_) => { | |
view.check_whole_access()?; | |
} | |
VectorOperations::DeleteVectors(PointIdsList { points, shard_key }, vectors) => { | |
if let Some(payload) = &view.payload { | |
if shard_key.is_some() { | |
// It is unclear where to put the shard_key | |
return incompatible_with_payload_constraint(view.collection); | |
} | |
*op = VectorOperations::DeleteVectorsByFilter( | |
make_filter_from_ids(take(points)).merge_owned(payload.to_filter()), | |
take(vectors), | |
); | |
} | |
} | |
VectorOperations::DeleteVectorsByFilter(filter, _) => { | |
if let Some(payload) = &view.payload { | |
*filter = take(filter).merge_owned(payload.to_filter()); | |
} | |
} | |
}, | |
CollectionUpdateOperations::PayloadOperation(op) => 'a: { | |
let Some(payload) = &view.payload else { | |
// Allow all operations when there is no payload constraint | |
break 'a; | |
}; | |
match op { | |
PayloadOps::SetPayload(SetPayloadOp { | |
payload: _, // TODO: validate | |
points, | |
filter, | |
key: _, // TODO: validate | |
}) => { | |
let filter = filter.get_or_insert_with(Default::default); | |
if let Some(points) = take(points) { | |
*filter = take(filter).merge_owned(make_filter_from_ids(points)); | |
} | |
// Reject as not implemented | |
return incompatible_with_payload_constraint(view.collection); | |
} | |
PayloadOps::DeletePayload(DeletePayloadOp { | |
keys: _, // TODO: validate | |
points, | |
filter, | |
}) => { | |
let filter = filter.get_or_insert_with(Default::default); | |
if let Some(points) = take(points) { | |
*filter = take(filter).merge_owned(make_filter_from_ids(points)); | |
} | |
// Reject as not implemented | |
return incompatible_with_payload_constraint(view.collection); | |
} | |
PayloadOps::ClearPayload { points } => { | |
*op = PayloadOps::OverwritePayload(SetPayloadOp { | |
payload: payload.make_payload(view.collection)?, | |
points: None, | |
filter: Some( | |
make_filter_from_ids(take(points)).merge_owned(payload.to_filter()), | |
), | |
key: None, | |
}); | |
} | |
PayloadOps::ClearPayloadByFilter(filter) => { | |
*op = PayloadOps::OverwritePayload(SetPayloadOp { | |
payload: payload.make_payload(view.collection)?, | |
points: None, | |
filter: Some(take(filter).merge_owned(payload.to_filter())), | |
key: None, | |
}); | |
} | |
PayloadOps::OverwritePayload(SetPayloadOp { | |
payload: _, // TODO: validate | |
points, | |
filter, | |
key: _, // TODO: validate | |
}) => { | |
let filter = filter.get_or_insert_with(Default::default); | |
if let Some(points) = take(points) { | |
*filter = take(filter).merge_owned(make_filter_from_ids(points)); | |
} | |
// Reject as not implemented | |
return incompatible_with_payload_constraint(view.collection); | |
} | |
} | |
} | |
CollectionUpdateOperations::FieldIndexOperation(_) => (), | |
} | |
Ok(()) | |
} | |
} | |
/// Create a `must` filter from a list of point IDs. | |
fn make_filter_from_ids(ids: Vec<ExtendedPointId>) -> Filter { | |
let cond = ids.into_iter().collect::<HashSet<_>>().into(); | |
Filter { | |
must: Some(vec![Condition::HasId(cond)]), | |
..Default::default() | |
} | |
} | |
impl PayloadConstraint { | |
/// Create a `must` filter. | |
fn to_filter(&self) -> Filter { | |
Filter { | |
must: Some( | |
self.0 | |
.iter() | |
.map(|(path, value)| { | |
Condition::Field(FieldCondition::new_match( | |
path.clone(), | |
Match::new_value(value.clone()), | |
)) | |
}) | |
.collect(), | |
), | |
..Default::default() | |
} | |
} | |
fn make_payload(&self, collection_name: &str) -> Result<Payload, StorageError> { | |
let _ = self; // TODO: We need to construct a payload, then validate it against the claim | |
incompatible_with_payload_constraint(collection_name) // Reject as not implemented | |
} | |
} | |
mod tests { | |
use std::collections::HashMap; | |
use segment::json_path::JsonPath; | |
use segment::types::{MinShould, ValueVariants}; | |
use super::*; | |
use crate::rbac::{CollectionAccess, CollectionAccessMode}; | |
fn test_apply_filter() { | |
let list = CollectionAccessList(vec![CollectionAccess { | |
collection: "col".to_string(), | |
access: CollectionAccessMode::Read, | |
payload: Some(PayloadConstraint(HashMap::from([( | |
"field".parse().unwrap(), | |
ValueVariants::Integer(42), | |
)]))), | |
}]); | |
let mut filter = None; | |
list.find_view("col").unwrap().apply_filter(&mut filter); | |
assert_eq!( | |
filter, | |
Some(Filter { | |
must: Some(vec![Condition::Field(FieldCondition::new_match( | |
"field".parse().unwrap(), | |
Match::new_value(ValueVariants::Integer(42)) | |
))]), | |
..Default::default() | |
}) | |
); | |
let cond = |path: &str| Condition::IsNull(path.parse::<JsonPath>().unwrap().into()); | |
let mut filter = Some(Filter { | |
should: Some(vec![cond("a")]), | |
min_should: Some(MinShould { | |
conditions: vec![cond("b")], | |
min_count: 1, | |
}), | |
must: Some(vec![cond("c")]), | |
must_not: Some(vec![cond("d")]), | |
}); | |
list.find_view("col").unwrap().apply_filter(&mut filter); | |
assert_eq!( | |
filter, | |
Some(Filter { | |
should: Some(vec![cond("a")]), | |
min_should: Some(MinShould { | |
conditions: vec![cond("b")], | |
min_count: 1, | |
}), | |
must: Some(vec![ | |
cond("c"), | |
Condition::Field(FieldCondition::new_match( | |
"field".parse().unwrap(), | |
Match::new_value(ValueVariants::Integer(42)) | |
)) | |
]), | |
must_not: Some(vec![cond("d")]), | |
}) | |
); | |
} | |
} | |
mod tests_ops { | |
use std::fmt::Debug; | |
use api::rest::{ | |
self, LookupLocation, OrderByInterface, RecommendStrategy, SearchRequestInternal, | |
}; | |
use collection::operations::payload_ops::PayloadOpsDiscriminants; | |
use collection::operations::point_ops::{ | |
BatchPersisted, BatchVectorStructPersisted, PointInsertOperationsInternal, | |
PointInsertOperationsInternalDiscriminants, PointOperationsDiscriminants, | |
PointStructPersisted, PointSyncOperation, VectorStructPersisted, | |
}; | |
use collection::operations::query_enum::QueryEnum; | |
use collection::operations::types::UsingVector; | |
use collection::operations::vector_ops::{ | |
PointVectorsPersisted, UpdateVectorsOp, VectorOperationsDiscriminants, | |
}; | |
use collection::operations::{ | |
CollectionUpdateOperationsDiscriminants, CreateIndex, FieldIndexOperations, | |
FieldIndexOperationsDiscriminants, | |
}; | |
use segment::data_types::vectors::NamedVectorStruct; | |
use segment::types::{PointIdType, SearchParams, WithPayloadInterface, WithVector}; | |
use strum::IntoEnumIterator as _; | |
use super::*; | |
use crate::rbac::{AccessCollectionBuilder, GlobalAccessMode}; | |
/// Operation is allowed with the given access, and no rewrite is expected. | |
fn assert_allowed<Op: Debug + Clone + PartialEq + CheckableCollectionOperation>( | |
op: &Op, | |
access: &Access, | |
) { | |
let mut op_actual = op.clone(); | |
access | |
.check_point_op("col", &mut op_actual) | |
.expect("Should be allowed"); | |
assert_eq!(op, &op_actual, "Expected not to change"); | |
} | |
/// Operation is allowed with the given access, and the rewrite is expected. | |
/// A closure `rewrite` is expected to produce the same result as the rewritten operation. | |
fn assert_allowed_rewrite<Op: Debug + Clone + PartialEq + CheckableCollectionOperation>( | |
op: &Op, | |
access: &Access, | |
rewrite: impl FnOnce(&mut Op), | |
) { | |
let mut op_actual = op.clone(); | |
access | |
.check_point_op("col", &mut op_actual) | |
.expect("Should be allowed"); | |
let mut op_reference = op.clone(); | |
rewrite(&mut op_reference); | |
assert_eq!(op_reference, op_actual, "Expected to change"); | |
} | |
/// Operation is forbidden with the given access. | |
fn assert_forbidden<Op: Clone + CheckableCollectionOperation + PartialEq>( | |
op: &Op, | |
access: &Access, | |
) { | |
access | |
.check_point_op("col", &mut op.clone()) | |
.expect_err("Should be allowed"); | |
} | |
/// Operation requires write + whole collection access. | |
fn assert_requires_whole_write_access<Op>(op: &Op) | |
where | |
Op: CheckableCollectionOperation + Clone + Debug + PartialEq, | |
{ | |
assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
op, | |
&AccessCollectionBuilder::new().add("col", true, true).into(), | |
); | |
assert_forbidden( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_forbidden( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", true, false) | |
.into(), | |
); | |
} | |
fn test_recommend_request_internal() { | |
let op = RecommendRequestInternal { | |
positive: vec![RecommendExample::Dense(vec![0.0, 1.0, 2.0])], | |
negative: vec![RecommendExample::Sparse(vec![(0, 0.0)].try_into().unwrap())], | |
strategy: Some(RecommendStrategy::AverageVector), | |
filter: None, | |
params: Some(SearchParams::default()), | |
limit: 100, | |
offset: Some(100), | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vector: Some(WithVector::Bool(true)), | |
score_threshold: Some(42.0), | |
using: Some(UsingVector::Name("vector".to_string())), | |
lookup_from: Some(LookupLocation { | |
collection: "col2".to_string(), | |
vector: Some("vector".to_string()), | |
shard_key: None, | |
}), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
// Require whole access to col2 | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, false) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
|op| { | |
op.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
}, | |
); | |
// Point ID is used | |
assert_forbidden( | |
&RecommendRequestInternal { | |
positive: vec![RecommendExample::PointId(ExtendedPointId::NumId(12345))], | |
..op.clone() | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
); | |
// lookup_from requires read access | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed( | |
&RecommendRequestInternal { | |
lookup_from: None, | |
..op | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
} | |
fn test_point_request_internal() { | |
let op = PointRequestInternal { | |
ids: vec![PointIdType::NumId(12345)], | |
with_payload: None, | |
with_vector: WithVector::Bool(true), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.into(), | |
); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
} | |
fn test_core_search_request() { | |
let op = CoreSearchRequest { | |
query: QueryEnum::Nearest(NamedVectorStruct::Default(vec![0.0, 1.0, 2.0])), | |
filter: None, | |
params: Some(SearchParams::default()), | |
limit: 100, | |
offset: 100, | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vector: Some(WithVector::Bool(true)), | |
score_threshold: Some(42.0), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.into(), | |
|op| { | |
op.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
}, | |
); | |
} | |
fn test_count_request_internal() { | |
let op = CountRequestInternal { | |
filter: None, | |
exact: false, | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.into(), | |
|op| { | |
op.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
}, | |
); | |
} | |
fn test_group_request_source() { | |
let op = GroupRequest { | |
// NOTE: SourceRequest::Recommend is already tested in test_recommend_request_internal | |
source: SourceRequest::Search(SearchRequestInternal { | |
vector: rest::NamedVectorStruct::Default(vec![0.0, 1.0, 2.0]), | |
filter: None, | |
params: Some(SearchParams::default()), | |
limit: 100, | |
offset: Some(100), | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vector: Some(WithVector::Bool(true)), | |
score_threshold: Some(42.0), | |
}), | |
group_by: "path".parse().unwrap(), | |
group_size: 100, | |
limit: 100, | |
with_lookup: Some(WithLookup { | |
collection_name: "col2".to_string(), | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vectors: Some(WithVector::Bool(true)), | |
}), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.add("col2", false, true) | |
.into(), | |
); | |
// with_lookup requires whole read access | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.add("col", false, false) | |
.into(), | |
); | |
assert_allowed( | |
&GroupRequest { | |
with_lookup: None, | |
..op.clone() | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
// filter rewrite | |
assert_allowed_rewrite( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
|op| match &mut op.source { | |
SourceRequest::Search(s) => { | |
s.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
} | |
SourceRequest::Recommend(_) => unreachable!(), | |
SourceRequest::Query(_) => unreachable!(), | |
}, | |
); | |
} | |
fn test_discover_request_internal() { | |
let op = DiscoverRequestInternal { | |
target: Some(RecommendExample::Dense(vec![0.0, 1.0, 2.0])), | |
context: Some(vec![ContextExamplePair { | |
positive: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), | |
negative: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), | |
}]), | |
filter: None, | |
params: Some(SearchParams::default()), | |
limit: 100, | |
offset: Some(100), | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vector: Some(WithVector::Bool(true)), | |
using: Some(UsingVector::Name("vector".to_string())), | |
lookup_from: Some(LookupLocation { | |
collection: "col2".to_string(), | |
vector: Some("vector".to_string()), | |
shard_key: None, | |
}), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.add("col2", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
|op| { | |
op.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
}, | |
); | |
// Point ID is used | |
assert_forbidden( | |
&DiscoverRequestInternal { | |
target: Some(RecommendExample::PointId(ExtendedPointId::NumId(12345))), | |
..op.clone() | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
); | |
assert_forbidden( | |
&DiscoverRequestInternal { | |
context: Some(vec![ContextExamplePair { | |
positive: RecommendExample::PointId(ExtendedPointId::NumId(12345)), | |
negative: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), | |
}]), | |
..op.clone() | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
); | |
assert_forbidden( | |
&DiscoverRequestInternal { | |
context: Some(vec![ContextExamplePair { | |
positive: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), | |
negative: RecommendExample::PointId(ExtendedPointId::NumId(12345)), | |
}]), | |
..op.clone() | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.add("col2", false, true) | |
.into(), | |
); | |
// lookup_from requires read access | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed( | |
&DiscoverRequestInternal { | |
lookup_from: None, | |
..op | |
}, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
} | |
fn test_scroll_request_internal() { | |
let op = ScrollRequestInternal { | |
offset: Some(ExtendedPointId::NumId(12345)), | |
limit: Some(100), | |
filter: None, | |
with_payload: Some(WithPayloadInterface::Bool(true)), | |
with_vector: WithVector::Bool(true), | |
order_by: Some(OrderByInterface::Key("path".parse().unwrap())), | |
}; | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
&ScrollRequestInternal { ..op }, | |
&AccessCollectionBuilder::new() | |
.add("col", false, false) | |
.into(), | |
|op| { | |
op.filter = Some(PayloadConstraint::new_test("col").to_filter()); | |
}, | |
); | |
} | |
fn test_collection_update_operations() { | |
CollectionUpdateOperationsDiscriminants::iter().for_each(|discr| match discr { | |
CollectionUpdateOperationsDiscriminants::PointOperation => { | |
check_collection_update_operations_points() | |
} | |
CollectionUpdateOperationsDiscriminants::VectorOperation => { | |
check_collection_update_operations_update_vectors() | |
} | |
CollectionUpdateOperationsDiscriminants::PayloadOperation => { | |
check_collection_update_operations_payload() | |
} | |
CollectionUpdateOperationsDiscriminants::FieldIndexOperation => { | |
check_collection_update_operations_field_index() | |
} | |
}); | |
} | |
/// Tests for [`CollectionUpdateOperations::PointOperation`]. | |
fn check_collection_update_operations_points() { | |
PointOperationsDiscriminants::iter().for_each(|discr| match discr { | |
PointOperationsDiscriminants::UpsertPoints => { | |
for discr in PointInsertOperationsInternalDiscriminants::iter() { | |
let inner = match discr { | |
PointInsertOperationsInternalDiscriminants::PointsBatch => { | |
PointInsertOperationsInternal::PointsBatch(BatchPersisted { | |
ids: vec![ExtendedPointId::NumId(12345)], | |
vectors: BatchVectorStructPersisted::Single(vec![vec![ | |
0.0, 1.0, 2.0, | |
]]), | |
payloads: None, | |
}) | |
} | |
PointInsertOperationsInternalDiscriminants::PointsList => { | |
PointInsertOperationsInternal::PointsList(vec![PointStructPersisted { | |
id: ExtendedPointId::NumId(12345), | |
vector: VectorStructPersisted::Single(vec![0.0, 1.0, 2.0]), | |
payload: None, | |
}]) | |
} | |
}; | |
let op = CollectionUpdateOperations::PointOperation( | |
PointOperations::UpsertPoints(inner), | |
); | |
assert_requires_whole_write_access(&op); | |
} | |
} | |
PointOperationsDiscriminants::DeletePoints => { | |
let op = | |
CollectionUpdateOperations::PointOperation(PointOperations::DeletePoints { | |
ids: vec![ExtendedPointId::NumId(12345)], | |
}); | |
check_collection_update_operations_delete_points(&op); | |
} | |
PointOperationsDiscriminants::DeletePointsByFilter => { | |
let op = CollectionUpdateOperations::PointOperation( | |
PointOperations::DeletePointsByFilter(make_filter_from_ids(vec![ | |
ExtendedPointId::NumId(12345), | |
])), | |
); | |
check_collection_update_operations_delete_points(&op); | |
} | |
PointOperationsDiscriminants::SyncPoints => { | |
let op = CollectionUpdateOperations::PointOperation(PointOperations::SyncPoints( | |
PointSyncOperation { | |
from_id: None, | |
to_id: None, | |
points: Vec::new(), | |
}, | |
)); | |
assert_requires_whole_write_access(&op); | |
} | |
}); | |
} | |
/// Tests for [`CollectionUpdateOperations::PointOperation`] with | |
/// [`PointOperations::DeletePoints`] and [`PointOperations::DeletePointsByFilter`]. | |
fn check_collection_update_operations_delete_points(op: &CollectionUpdateOperations) { | |
assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
op, | |
&AccessCollectionBuilder::new().add("col", true, true).into(), | |
); | |
assert_forbidden( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", true, false) | |
.into(), | |
|op| { | |
*op = CollectionUpdateOperations::PointOperation( | |
PointOperations::DeletePointsByFilter( | |
make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]) | |
.merge_owned(PayloadConstraint::new_test("col").to_filter()), | |
), | |
); | |
}, | |
); | |
} | |
/// Tests for [`CollectionUpdateOperations::VectorOperation`]. | |
fn check_collection_update_operations_update_vectors() { | |
VectorOperationsDiscriminants::iter().for_each(|discr| match discr { | |
VectorOperationsDiscriminants::UpdateVectors => { | |
let op = CollectionUpdateOperations::VectorOperation( | |
VectorOperations::UpdateVectors(UpdateVectorsOp { | |
points: vec![PointVectorsPersisted { | |
id: ExtendedPointId::NumId(12345), | |
vector: VectorStructPersisted::Single(vec![0.0, 1.0, 2.0]), | |
}], | |
}), | |
); | |
assert_requires_whole_write_access(&op); | |
} | |
VectorOperationsDiscriminants::DeleteVectors => { | |
let op = | |
CollectionUpdateOperations::VectorOperation(VectorOperations::DeleteVectors( | |
PointIdsList { | |
points: vec![ExtendedPointId::NumId(12345)], | |
shard_key: None, | |
}, | |
vec!["vector".to_string()], | |
)); | |
check_collection_update_operations_delete_vectors(&op); | |
} | |
VectorOperationsDiscriminants::DeleteVectorsByFilter => { | |
let op = CollectionUpdateOperations::VectorOperation( | |
VectorOperations::DeleteVectorsByFilter( | |
make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]), | |
vec!["vector".to_string()], | |
), | |
); | |
check_collection_update_operations_delete_vectors(&op); | |
} | |
}); | |
} | |
/// Tests for [`CollectionUpdateOperations::VectorOperation`] with | |
/// [`VectorOperations::DeleteVectors`] and [`VectorOperations::DeleteVectorsByFilter`]. | |
fn check_collection_update_operations_delete_vectors(op: &CollectionUpdateOperations) { | |
assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); | |
assert_allowed( | |
op, | |
&AccessCollectionBuilder::new().add("col", true, true).into(), | |
); | |
assert_forbidden( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
assert_allowed_rewrite( | |
op, | |
&AccessCollectionBuilder::new() | |
.add("col", true, false) | |
.into(), | |
|op| { | |
*op = CollectionUpdateOperations::VectorOperation( | |
VectorOperations::DeleteVectorsByFilter( | |
make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]) | |
.merge_owned(PayloadConstraint::new_test("col").to_filter()), | |
vec!["vector".to_string()], | |
), | |
); | |
}, | |
); | |
} | |
/// Tests for [`CollectionUpdateOperations::PayloadOperation`]. | |
fn check_collection_update_operations_payload() { | |
for discr in PayloadOpsDiscriminants::iter() { | |
let inner = match discr { | |
PayloadOpsDiscriminants::SetPayload => PayloadOps::SetPayload(SetPayloadOp { | |
payload: Payload::default(), | |
points: Some(vec![ExtendedPointId::NumId(12345)]), | |
filter: None, | |
key: None, | |
}), | |
PayloadOpsDiscriminants::DeletePayload => { | |
PayloadOps::DeletePayload(DeletePayloadOp { | |
keys: vec!["path".parse().unwrap()], | |
points: Some(vec![ExtendedPointId::NumId(12345)]), | |
filter: None, | |
}) | |
} | |
PayloadOpsDiscriminants::ClearPayload => PayloadOps::ClearPayload { | |
points: vec![ExtendedPointId::NumId(12345)], | |
}, | |
PayloadOpsDiscriminants::ClearPayloadByFilter => { | |
PayloadOps::ClearPayloadByFilter(make_filter_from_ids(vec![ | |
ExtendedPointId::NumId(12345), | |
])) | |
} | |
PayloadOpsDiscriminants::OverwritePayload => { | |
PayloadOps::OverwritePayload(SetPayloadOp { | |
payload: Payload::default(), | |
points: Some(vec![ExtendedPointId::NumId(12345)]), | |
filter: None, | |
key: None, | |
}) | |
} | |
}; | |
let op = CollectionUpdateOperations::PayloadOperation(inner); | |
assert_requires_whole_write_access(&op); | |
} | |
} | |
/// Tests for [`CollectionUpdateOperations::FieldIndexOperation`]. | |
fn check_collection_update_operations_field_index() { | |
for discr in FieldIndexOperationsDiscriminants::iter() { | |
let inner = match discr { | |
FieldIndexOperationsDiscriminants::CreateIndex => { | |
FieldIndexOperations::CreateIndex(CreateIndex { | |
field_name: "path".parse().unwrap(), | |
field_schema: None, | |
}) | |
} | |
FieldIndexOperationsDiscriminants::DeleteIndex => { | |
FieldIndexOperations::DeleteIndex("path".parse().unwrap()) | |
} | |
}; | |
let op = CollectionUpdateOperations::FieldIndexOperation(inner); | |
assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); | |
assert_forbidden(&op, &Access::Global(GlobalAccessMode::Read)); | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new().add("col", true, true).into(), | |
); | |
assert_forbidden( | |
&op, | |
&AccessCollectionBuilder::new() | |
.add("col", false, true) | |
.into(), | |
); | |
} | |
} | |
} | |