use std::borrow::Cow; use std::collections::HashSet; use std::mem::take; use api::rest::LookupLocation; use collection::collection::distance_matrix::CollectionSearchMatrixRequest; use collection::grouping::group_by::{GroupRequest, SourceRequest}; use collection::lookup::WithLookup; use collection::operations::payload_ops::{DeletePayloadOp, PayloadOps, SetPayloadOp}; use collection::operations::point_ops::{PointIdsList, PointOperations}; use collection::operations::types::{ ContextExamplePair, CoreSearchRequest, CountRequestInternal, DiscoverRequestInternal, PointRequestInternal, RecommendExample, RecommendRequestInternal, ScrollRequestInternal, }; use collection::operations::universal_query::collection_query::{ CollectionPrefetch, CollectionQueryRequest, Query, VectorInputInternal, VectorQuery, }; use collection::operations::vector_ops::VectorOperations; use collection::operations::CollectionUpdateOperations; use segment::data_types::facets::FacetParams; use segment::types::{Condition, ExtendedPointId, FieldCondition, Filter, Match, Payload}; use super::{ incompatible_with_payload_constraint, Access, AccessRequirements, CollectionAccessList, CollectionAccessView, CollectionPass, PayloadConstraint, }; use crate::content_manager::collection_meta_ops::CollectionMetaOperations; use crate::content_manager::errors::{StorageError, StorageResult}; impl Access { #[allow(private_bounds)] pub(crate) fn check_point_op<'a>( &self, collection_name: &'a str, op: &mut impl CheckableCollectionOperation, ) -> Result, StorageError> { let requirements = op.access_requirements(); match self { Access::Global(mode) => mode.meets_requirements(requirements)?, Access::Collection(list) => { let view = list.find_view(collection_name)?; view.meets_requirements(requirements)?; op.check_access(view, list)?; } } Ok(CollectionPass(Cow::Borrowed(collection_name))) } pub(crate) fn check_collection_meta_operation( &self, operation: &CollectionMetaOperations, ) -> Result<(), StorageError> { match operation { CollectionMetaOperations::CreateCollection(_) | CollectionMetaOperations::UpdateCollection(_) | CollectionMetaOperations::DeleteCollection(_) | CollectionMetaOperations::ChangeAliases(_) | CollectionMetaOperations::Resharding(_, _) | CollectionMetaOperations::TransferShard(_, _) | CollectionMetaOperations::SetShardReplicaState(_) | CollectionMetaOperations::CreateShardKey(_) | CollectionMetaOperations::DropShardKey(_) => { self.check_global_access(AccessRequirements::new().manage())?; } CollectionMetaOperations::CreatePayloadIndex(op) => { self.check_collection_access( &op.collection_name, AccessRequirements::new().write().whole(), )?; } CollectionMetaOperations::DropPayloadIndex(op) => { self.check_collection_access( &op.collection_name, AccessRequirements::new().write().whole(), )?; } CollectionMetaOperations::Nop { token: _ } => (), } Ok(()) } } trait CheckableCollectionOperation { /// Used to distinguish whether the operation is read-only or read-write. fn access_requirements(&self) -> AccessRequirements; fn check_access( &mut self, view: CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError>; } impl CollectionAccessList { fn check_lookup_from( &self, lookup_location: &Option, ) -> Result<(), StorageError> { if let Some(lookup_location) = lookup_location { self.find_view(&lookup_location.collection)? .check_whole_access()?; } Ok(()) } fn check_with_lookup(&self, with_lookup: &Option) -> Result<(), StorageError> { if let Some(with_lookup) = with_lookup { self.find_view(&with_lookup.collection_name)? .check_whole_access()?; } Ok(()) } } impl<'a> CollectionAccessView<'a> { fn apply_filter(&self, filter: &mut Option) { if let Some(payload) = &self.payload { let f = filter.get_or_insert_with(Default::default); *f = take(f).merge_owned(payload.to_filter()); } } fn check_recommend_example(&self, example: &RecommendExample) -> Result<(), StorageError> { match example { RecommendExample::PointId(_) => self.check_whole_access(), RecommendExample::Dense(_) | RecommendExample::Sparse(_) => Ok(()), } } fn check_vector_query( &self, vector_query: &VectorQuery, ) -> Result<(), StorageError> { match vector_query { VectorQuery::Nearest(nearest) => self.check_vector_input(nearest)?, VectorQuery::RecommendBestScore(reco) | VectorQuery::RecommendAverageVector(reco) => { for vector_input in reco.flat_iter() { self.check_vector_input(vector_input)? } } VectorQuery::Discover(discover) => { for vector_input in discover.flat_iter() { self.check_vector_input(vector_input)? } } VectorQuery::Context(context) => { for vector_input in context.flat_iter() { self.check_vector_input(vector_input)? } } }; Ok(()) } fn check_vector_input(&self, vector_input: &VectorInputInternal) -> Result<(), StorageError> { match vector_input { VectorInputInternal::Vector(_) => Ok(()), VectorInputInternal::Id(_) => self.check_whole_access(), } } } impl CheckableCollectionOperation for RecommendRequestInternal { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError> { for e in &self.positive { view.check_recommend_example(e)?; } for e in &self.negative { view.check_recommend_example(e)?; } access.check_lookup_from(&self.lookup_from)?; view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for PointRequestInternal { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: true, } } fn check_access( &mut self, _view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> Result<(), StorageError> { Ok(()) } } impl CheckableCollectionOperation for CoreSearchRequest { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> Result<(), StorageError> { view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for CountRequestInternal { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> Result<(), StorageError> { view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for GroupRequest { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError> { match &mut self.source { SourceRequest::Search(s) => { view.apply_filter(&mut s.filter); } SourceRequest::Recommend(r) => r.check_access(view, access)?, SourceRequest::Query(q) => q.check_access(view, access)?, } access.check_with_lookup(&self.with_lookup)?; Ok(()) } } impl CheckableCollectionOperation for DiscoverRequestInternal { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError> { if let Some(target) = &self.target { view.check_recommend_example(target)?; } for ContextExamplePair { positive, negative } in self.context.iter().flat_map(|c| c.iter()) { view.check_recommend_example(positive)?; view.check_recommend_example(negative)?; } view.apply_filter(&mut self.filter); access.check_lookup_from(&self.lookup_from)?; Ok(()) } } impl CheckableCollectionOperation for ScrollRequestInternal { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> Result<(), StorageError> { view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for CollectionQueryRequest { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError> { view.apply_filter(&mut self.filter); if let Some(Query::Vector(vector_query)) = &self.query { view.check_vector_query(vector_query)? } access.check_lookup_from(&self.lookup_from)?; for prefetch_query in self.prefetch.iter_mut() { check_access_for_prefetch(prefetch_query, &view, access)?; } Ok(()) } } fn check_access_for_prefetch( prefetch: &mut CollectionPrefetch, view: &CollectionAccessView<'_>, access: &CollectionAccessList, ) -> Result<(), StorageError> { view.apply_filter(&mut prefetch.filter); if let Some(Query::Vector(vector_query)) = &prefetch.query { view.check_vector_query(vector_query)? } access.check_lookup_from(&prefetch.lookup_from)?; // Recurse inner prefetches for prefetch_query in prefetch.prefetch.iter_mut() { check_access_for_prefetch(prefetch_query, view, access)?; } Ok(()) } impl CheckableCollectionOperation for FacetParams { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> StorageResult<()> { view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for CollectionSearchMatrixRequest { fn access_requirements(&self) -> AccessRequirements { AccessRequirements { write: false, manage: false, whole: false, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> StorageResult<()> { view.apply_filter(&mut self.filter); Ok(()) } } impl CheckableCollectionOperation for CollectionUpdateOperations { fn access_requirements(&self) -> AccessRequirements { match self { CollectionUpdateOperations::PointOperation(_) | CollectionUpdateOperations::VectorOperation(_) | CollectionUpdateOperations::PayloadOperation(_) => AccessRequirements { write: true, manage: false, whole: false, // Checked in `check_access()` }, CollectionUpdateOperations::FieldIndexOperation(_) => AccessRequirements { write: true, manage: true, whole: true, }, } } fn check_access( &mut self, view: CollectionAccessView<'_>, _access: &CollectionAccessList, ) -> Result<(), StorageError> { match self { CollectionUpdateOperations::PointOperation(op) => match op { PointOperations::UpsertPoints(_) => { view.check_whole_access()?; } PointOperations::DeletePoints { ids } => { if let Some(payload) = &view.payload { *op = PointOperations::DeletePointsByFilter( make_filter_from_ids(take(ids)).merge_owned(payload.to_filter()), ); } } PointOperations::DeletePointsByFilter(filter) => { if let Some(payload) = &view.payload { *filter = take(filter).merge_owned(payload.to_filter()); } } PointOperations::SyncPoints(_) => { view.check_whole_access()?; } }, CollectionUpdateOperations::VectorOperation(op) => match op { VectorOperations::UpdateVectors(_) => { view.check_whole_access()?; } VectorOperations::DeleteVectors(PointIdsList { points, shard_key }, vectors) => { if let Some(payload) = &view.payload { if shard_key.is_some() { // It is unclear where to put the shard_key return incompatible_with_payload_constraint(view.collection); } *op = VectorOperations::DeleteVectorsByFilter( make_filter_from_ids(take(points)).merge_owned(payload.to_filter()), take(vectors), ); } } VectorOperations::DeleteVectorsByFilter(filter, _) => { if let Some(payload) = &view.payload { *filter = take(filter).merge_owned(payload.to_filter()); } } }, CollectionUpdateOperations::PayloadOperation(op) => 'a: { let Some(payload) = &view.payload else { // Allow all operations when there is no payload constraint break 'a; }; match op { PayloadOps::SetPayload(SetPayloadOp { payload: _, // TODO: validate points, filter, key: _, // TODO: validate }) => { let filter = filter.get_or_insert_with(Default::default); if let Some(points) = take(points) { *filter = take(filter).merge_owned(make_filter_from_ids(points)); } // Reject as not implemented return incompatible_with_payload_constraint(view.collection); } PayloadOps::DeletePayload(DeletePayloadOp { keys: _, // TODO: validate points, filter, }) => { let filter = filter.get_or_insert_with(Default::default); if let Some(points) = take(points) { *filter = take(filter).merge_owned(make_filter_from_ids(points)); } // Reject as not implemented return incompatible_with_payload_constraint(view.collection); } PayloadOps::ClearPayload { points } => { *op = PayloadOps::OverwritePayload(SetPayloadOp { payload: payload.make_payload(view.collection)?, points: None, filter: Some( make_filter_from_ids(take(points)).merge_owned(payload.to_filter()), ), key: None, }); } PayloadOps::ClearPayloadByFilter(filter) => { *op = PayloadOps::OverwritePayload(SetPayloadOp { payload: payload.make_payload(view.collection)?, points: None, filter: Some(take(filter).merge_owned(payload.to_filter())), key: None, }); } PayloadOps::OverwritePayload(SetPayloadOp { payload: _, // TODO: validate points, filter, key: _, // TODO: validate }) => { let filter = filter.get_or_insert_with(Default::default); if let Some(points) = take(points) { *filter = take(filter).merge_owned(make_filter_from_ids(points)); } // Reject as not implemented return incompatible_with_payload_constraint(view.collection); } } } CollectionUpdateOperations::FieldIndexOperation(_) => (), } Ok(()) } } /// Create a `must` filter from a list of point IDs. fn make_filter_from_ids(ids: Vec) -> Filter { let cond = ids.into_iter().collect::>().into(); Filter { must: Some(vec![Condition::HasId(cond)]), ..Default::default() } } impl PayloadConstraint { /// Create a `must` filter. fn to_filter(&self) -> Filter { Filter { must: Some( self.0 .iter() .map(|(path, value)| { Condition::Field(FieldCondition::new_match( path.clone(), Match::new_value(value.clone()), )) }) .collect(), ), ..Default::default() } } fn make_payload(&self, collection_name: &str) -> Result { let _ = self; // TODO: We need to construct a payload, then validate it against the claim incompatible_with_payload_constraint(collection_name) // Reject as not implemented } } #[cfg(test)] mod tests { use std::collections::HashMap; use segment::json_path::JsonPath; use segment::types::{MinShould, ValueVariants}; use super::*; use crate::rbac::{CollectionAccess, CollectionAccessMode}; #[test] fn test_apply_filter() { let list = CollectionAccessList(vec![CollectionAccess { collection: "col".to_string(), access: CollectionAccessMode::Read, payload: Some(PayloadConstraint(HashMap::from([( "field".parse().unwrap(), ValueVariants::Integer(42), )]))), }]); let mut filter = None; list.find_view("col").unwrap().apply_filter(&mut filter); assert_eq!( filter, Some(Filter { must: Some(vec![Condition::Field(FieldCondition::new_match( "field".parse().unwrap(), Match::new_value(ValueVariants::Integer(42)) ))]), ..Default::default() }) ); let cond = |path: &str| Condition::IsNull(path.parse::().unwrap().into()); let mut filter = Some(Filter { should: Some(vec![cond("a")]), min_should: Some(MinShould { conditions: vec![cond("b")], min_count: 1, }), must: Some(vec![cond("c")]), must_not: Some(vec![cond("d")]), }); list.find_view("col").unwrap().apply_filter(&mut filter); assert_eq!( filter, Some(Filter { should: Some(vec![cond("a")]), min_should: Some(MinShould { conditions: vec![cond("b")], min_count: 1, }), must: Some(vec![ cond("c"), Condition::Field(FieldCondition::new_match( "field".parse().unwrap(), Match::new_value(ValueVariants::Integer(42)) )) ]), must_not: Some(vec![cond("d")]), }) ); } } #[cfg(test)] mod tests_ops { use std::fmt::Debug; use api::rest::{ self, LookupLocation, OrderByInterface, RecommendStrategy, SearchRequestInternal, }; use collection::operations::payload_ops::PayloadOpsDiscriminants; use collection::operations::point_ops::{ BatchPersisted, BatchVectorStructPersisted, PointInsertOperationsInternal, PointInsertOperationsInternalDiscriminants, PointOperationsDiscriminants, PointStructPersisted, PointSyncOperation, VectorStructPersisted, }; use collection::operations::query_enum::QueryEnum; use collection::operations::types::UsingVector; use collection::operations::vector_ops::{ PointVectorsPersisted, UpdateVectorsOp, VectorOperationsDiscriminants, }; use collection::operations::{ CollectionUpdateOperationsDiscriminants, CreateIndex, FieldIndexOperations, FieldIndexOperationsDiscriminants, }; use segment::data_types::vectors::NamedVectorStruct; use segment::types::{PointIdType, SearchParams, WithPayloadInterface, WithVector}; use strum::IntoEnumIterator as _; use super::*; use crate::rbac::{AccessCollectionBuilder, GlobalAccessMode}; /// Operation is allowed with the given access, and no rewrite is expected. fn assert_allowed( op: &Op, access: &Access, ) { let mut op_actual = op.clone(); access .check_point_op("col", &mut op_actual) .expect("Should be allowed"); assert_eq!(op, &op_actual, "Expected not to change"); } /// Operation is allowed with the given access, and the rewrite is expected. /// A closure `rewrite` is expected to produce the same result as the rewritten operation. fn assert_allowed_rewrite( op: &Op, access: &Access, rewrite: impl FnOnce(&mut Op), ) { let mut op_actual = op.clone(); access .check_point_op("col", &mut op_actual) .expect("Should be allowed"); let mut op_reference = op.clone(); rewrite(&mut op_reference); assert_eq!(op_reference, op_actual, "Expected to change"); } /// Operation is forbidden with the given access. fn assert_forbidden( op: &Op, access: &Access, ) { access .check_point_op("col", &mut op.clone()) .expect_err("Should be allowed"); } /// Operation requires write + whole collection access. fn assert_requires_whole_write_access(op: &Op) where Op: CheckableCollectionOperation + Clone + Debug + PartialEq, { assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( op, &AccessCollectionBuilder::new().add("col", true, true).into(), ); assert_forbidden( op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_forbidden( op, &AccessCollectionBuilder::new() .add("col", true, false) .into(), ); } #[test] fn test_recommend_request_internal() { let op = RecommendRequestInternal { positive: vec![RecommendExample::Dense(vec![0.0, 1.0, 2.0])], negative: vec![RecommendExample::Sparse(vec![(0, 0.0)].try_into().unwrap())], strategy: Some(RecommendStrategy::AverageVector), filter: None, params: Some(SearchParams::default()), limit: 100, offset: Some(100), with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(WithVector::Bool(true)), score_threshold: Some(42.0), using: Some(UsingVector::Name("vector".to_string())), lookup_from: Some(LookupLocation { collection: "col2".to_string(), vector: Some("vector".to_string()), shard_key: None, }), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); // Require whole access to col2 assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, false) .into(), ); assert_allowed_rewrite( &op, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), |op| { op.filter = Some(PayloadConstraint::new_test("col").to_filter()); }, ); // Point ID is used assert_forbidden( &RecommendRequestInternal { positive: vec![RecommendExample::PointId(ExtendedPointId::NumId(12345))], ..op.clone() }, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), ); // lookup_from requires read access assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed( &RecommendRequestInternal { lookup_from: None, ..op }, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); } #[test] fn test_point_request_internal() { let op = PointRequestInternal { ids: vec![PointIdType::NumId(12345)], with_payload: None, with_vector: WithVector::Bool(true), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, false) .into(), ); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); } #[test] fn test_core_search_request() { let op = CoreSearchRequest { query: QueryEnum::Nearest(NamedVectorStruct::Default(vec![0.0, 1.0, 2.0])), filter: None, params: Some(SearchParams::default()), limit: 100, offset: 100, with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(WithVector::Bool(true)), score_threshold: Some(42.0), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed_rewrite( &op, &AccessCollectionBuilder::new() .add("col", false, false) .into(), |op| { op.filter = Some(PayloadConstraint::new_test("col").to_filter()); }, ); } #[test] fn test_count_request_internal() { let op = CountRequestInternal { filter: None, exact: false, }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed_rewrite( &op, &AccessCollectionBuilder::new() .add("col", false, false) .into(), |op| { op.filter = Some(PayloadConstraint::new_test("col").to_filter()); }, ); } #[test] fn test_group_request_source() { let op = GroupRequest { // NOTE: SourceRequest::Recommend is already tested in test_recommend_request_internal source: SourceRequest::Search(SearchRequestInternal { vector: rest::NamedVectorStruct::Default(vec![0.0, 1.0, 2.0]), filter: None, params: Some(SearchParams::default()), limit: 100, offset: Some(100), with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(WithVector::Bool(true)), score_threshold: Some(42.0), }), group_by: "path".parse().unwrap(), group_size: 100, limit: 100, with_lookup: Some(WithLookup { collection_name: "col2".to_string(), with_payload: Some(WithPayloadInterface::Bool(true)), with_vectors: Some(WithVector::Bool(true)), }), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .add("col2", false, true) .into(), ); // with_lookup requires whole read access assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, true) .add("col", false, false) .into(), ); assert_allowed( &GroupRequest { with_lookup: None, ..op.clone() }, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); // filter rewrite assert_allowed_rewrite( &op, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), |op| match &mut op.source { SourceRequest::Search(s) => { s.filter = Some(PayloadConstraint::new_test("col").to_filter()); } SourceRequest::Recommend(_) => unreachable!(), SourceRequest::Query(_) => unreachable!(), }, ); } #[test] fn test_discover_request_internal() { let op = DiscoverRequestInternal { target: Some(RecommendExample::Dense(vec![0.0, 1.0, 2.0])), context: Some(vec![ContextExamplePair { positive: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), negative: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), }]), filter: None, params: Some(SearchParams::default()), limit: 100, offset: Some(100), with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(WithVector::Bool(true)), using: Some(UsingVector::Name("vector".to_string())), lookup_from: Some(LookupLocation { collection: "col2".to_string(), vector: Some("vector".to_string()), shard_key: None, }), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .add("col2", false, true) .into(), ); assert_allowed_rewrite( &op, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), |op| { op.filter = Some(PayloadConstraint::new_test("col").to_filter()); }, ); // Point ID is used assert_forbidden( &DiscoverRequestInternal { target: Some(RecommendExample::PointId(ExtendedPointId::NumId(12345))), ..op.clone() }, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), ); assert_forbidden( &DiscoverRequestInternal { context: Some(vec![ContextExamplePair { positive: RecommendExample::PointId(ExtendedPointId::NumId(12345)), negative: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), }]), ..op.clone() }, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), ); assert_forbidden( &DiscoverRequestInternal { context: Some(vec![ContextExamplePair { positive: RecommendExample::Dense(vec![0.0, 1.0, 2.0]), negative: RecommendExample::PointId(ExtendedPointId::NumId(12345)), }]), ..op.clone() }, &AccessCollectionBuilder::new() .add("col", false, false) .add("col2", false, true) .into(), ); // lookup_from requires read access assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed( &DiscoverRequestInternal { lookup_from: None, ..op }, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); } #[test] fn test_scroll_request_internal() { let op = ScrollRequestInternal { offset: Some(ExtendedPointId::NumId(12345)), limit: Some(100), filter: None, with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: WithVector::Bool(true), order_by: Some(OrderByInterface::Key("path".parse().unwrap())), }; assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_allowed(&op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed_rewrite( &ScrollRequestInternal { ..op }, &AccessCollectionBuilder::new() .add("col", false, false) .into(), |op| { op.filter = Some(PayloadConstraint::new_test("col").to_filter()); }, ); } #[test] fn test_collection_update_operations() { CollectionUpdateOperationsDiscriminants::iter().for_each(|discr| match discr { CollectionUpdateOperationsDiscriminants::PointOperation => { check_collection_update_operations_points() } CollectionUpdateOperationsDiscriminants::VectorOperation => { check_collection_update_operations_update_vectors() } CollectionUpdateOperationsDiscriminants::PayloadOperation => { check_collection_update_operations_payload() } CollectionUpdateOperationsDiscriminants::FieldIndexOperation => { check_collection_update_operations_field_index() } }); } /// Tests for [`CollectionUpdateOperations::PointOperation`]. fn check_collection_update_operations_points() { PointOperationsDiscriminants::iter().for_each(|discr| match discr { PointOperationsDiscriminants::UpsertPoints => { for discr in PointInsertOperationsInternalDiscriminants::iter() { let inner = match discr { PointInsertOperationsInternalDiscriminants::PointsBatch => { PointInsertOperationsInternal::PointsBatch(BatchPersisted { ids: vec![ExtendedPointId::NumId(12345)], vectors: BatchVectorStructPersisted::Single(vec![vec![ 0.0, 1.0, 2.0, ]]), payloads: None, }) } PointInsertOperationsInternalDiscriminants::PointsList => { PointInsertOperationsInternal::PointsList(vec![PointStructPersisted { id: ExtendedPointId::NumId(12345), vector: VectorStructPersisted::Single(vec![0.0, 1.0, 2.0]), payload: None, }]) } }; let op = CollectionUpdateOperations::PointOperation( PointOperations::UpsertPoints(inner), ); assert_requires_whole_write_access(&op); } } PointOperationsDiscriminants::DeletePoints => { let op = CollectionUpdateOperations::PointOperation(PointOperations::DeletePoints { ids: vec![ExtendedPointId::NumId(12345)], }); check_collection_update_operations_delete_points(&op); } PointOperationsDiscriminants::DeletePointsByFilter => { let op = CollectionUpdateOperations::PointOperation( PointOperations::DeletePointsByFilter(make_filter_from_ids(vec![ ExtendedPointId::NumId(12345), ])), ); check_collection_update_operations_delete_points(&op); } PointOperationsDiscriminants::SyncPoints => { let op = CollectionUpdateOperations::PointOperation(PointOperations::SyncPoints( PointSyncOperation { from_id: None, to_id: None, points: Vec::new(), }, )); assert_requires_whole_write_access(&op); } }); } /// Tests for [`CollectionUpdateOperations::PointOperation`] with /// [`PointOperations::DeletePoints`] and [`PointOperations::DeletePointsByFilter`]. fn check_collection_update_operations_delete_points(op: &CollectionUpdateOperations) { assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( op, &AccessCollectionBuilder::new().add("col", true, true).into(), ); assert_forbidden( op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed_rewrite( op, &AccessCollectionBuilder::new() .add("col", true, false) .into(), |op| { *op = CollectionUpdateOperations::PointOperation( PointOperations::DeletePointsByFilter( make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]) .merge_owned(PayloadConstraint::new_test("col").to_filter()), ), ); }, ); } /// Tests for [`CollectionUpdateOperations::VectorOperation`]. fn check_collection_update_operations_update_vectors() { VectorOperationsDiscriminants::iter().for_each(|discr| match discr { VectorOperationsDiscriminants::UpdateVectors => { let op = CollectionUpdateOperations::VectorOperation( VectorOperations::UpdateVectors(UpdateVectorsOp { points: vec![PointVectorsPersisted { id: ExtendedPointId::NumId(12345), vector: VectorStructPersisted::Single(vec![0.0, 1.0, 2.0]), }], }), ); assert_requires_whole_write_access(&op); } VectorOperationsDiscriminants::DeleteVectors => { let op = CollectionUpdateOperations::VectorOperation(VectorOperations::DeleteVectors( PointIdsList { points: vec![ExtendedPointId::NumId(12345)], shard_key: None, }, vec!["vector".to_string()], )); check_collection_update_operations_delete_vectors(&op); } VectorOperationsDiscriminants::DeleteVectorsByFilter => { let op = CollectionUpdateOperations::VectorOperation( VectorOperations::DeleteVectorsByFilter( make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]), vec!["vector".to_string()], ), ); check_collection_update_operations_delete_vectors(&op); } }); } /// Tests for [`CollectionUpdateOperations::VectorOperation`] with /// [`VectorOperations::DeleteVectors`] and [`VectorOperations::DeleteVectorsByFilter`]. fn check_collection_update_operations_delete_vectors(op: &CollectionUpdateOperations) { assert_allowed(op, &Access::Global(GlobalAccessMode::Manage)); assert_forbidden(op, &Access::Global(GlobalAccessMode::Read)); assert_allowed( op, &AccessCollectionBuilder::new().add("col", true, true).into(), ); assert_forbidden( op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); assert_allowed_rewrite( op, &AccessCollectionBuilder::new() .add("col", true, false) .into(), |op| { *op = CollectionUpdateOperations::VectorOperation( VectorOperations::DeleteVectorsByFilter( make_filter_from_ids(vec![ExtendedPointId::NumId(12345)]) .merge_owned(PayloadConstraint::new_test("col").to_filter()), vec!["vector".to_string()], ), ); }, ); } /// Tests for [`CollectionUpdateOperations::PayloadOperation`]. fn check_collection_update_operations_payload() { for discr in PayloadOpsDiscriminants::iter() { let inner = match discr { PayloadOpsDiscriminants::SetPayload => PayloadOps::SetPayload(SetPayloadOp { payload: Payload::default(), points: Some(vec![ExtendedPointId::NumId(12345)]), filter: None, key: None, }), PayloadOpsDiscriminants::DeletePayload => { PayloadOps::DeletePayload(DeletePayloadOp { keys: vec!["path".parse().unwrap()], points: Some(vec![ExtendedPointId::NumId(12345)]), filter: None, }) } PayloadOpsDiscriminants::ClearPayload => PayloadOps::ClearPayload { points: vec![ExtendedPointId::NumId(12345)], }, PayloadOpsDiscriminants::ClearPayloadByFilter => { PayloadOps::ClearPayloadByFilter(make_filter_from_ids(vec![ ExtendedPointId::NumId(12345), ])) } PayloadOpsDiscriminants::OverwritePayload => { PayloadOps::OverwritePayload(SetPayloadOp { payload: Payload::default(), points: Some(vec![ExtendedPointId::NumId(12345)]), filter: None, key: None, }) } }; let op = CollectionUpdateOperations::PayloadOperation(inner); assert_requires_whole_write_access(&op); } } /// Tests for [`CollectionUpdateOperations::FieldIndexOperation`]. fn check_collection_update_operations_field_index() { for discr in FieldIndexOperationsDiscriminants::iter() { let inner = match discr { FieldIndexOperationsDiscriminants::CreateIndex => { FieldIndexOperations::CreateIndex(CreateIndex { field_name: "path".parse().unwrap(), field_schema: None, }) } FieldIndexOperationsDiscriminants::DeleteIndex => { FieldIndexOperations::DeleteIndex("path".parse().unwrap()) } }; let op = CollectionUpdateOperations::FieldIndexOperation(inner); assert_allowed(&op, &Access::Global(GlobalAccessMode::Manage)); assert_forbidden(&op, &Access::Global(GlobalAccessMode::Read)); assert_forbidden( &op, &AccessCollectionBuilder::new().add("col", true, true).into(), ); assert_forbidden( &op, &AccessCollectionBuilder::new() .add("col", false, true) .into(), ); } } }