use std::collections::BTreeMap; use std::num::NonZeroU32; use std::path::Path; use api::rest::SearchRequestInternal; use collection::collection::Collection; use collection::config::{CollectionConfigInternal, CollectionParams, WalConfig}; use collection::operations::point_ops::{ PointInsertOperationsInternal, PointOperations, PointStructPersisted, VectorStructPersisted, WriteOrdering, }; use collection::operations::shard_selector_internal::ShardSelectorInternal; use collection::operations::types::{ CollectionError, PointRequestInternal, RecommendRequestInternal, VectorsConfig, }; use collection::operations::vector_params_builder::VectorParamsBuilder; use collection::operations::CollectionUpdateOperations; use collection::recommendations::recommend_by; use common::counter::hardware_accumulator::HwMeasurementAcc; use segment::data_types::named_vectors::NamedVectors; use segment::data_types::vectors::{NamedVector, VectorStructInternal}; use segment::types::{Distance, WithPayloadInterface, WithVector}; use tempfile::Builder; use crate::common::{new_local_collection, N_SHARDS, TEST_OPTIMIZERS_CONFIG}; const VEC_NAME1: &str = "vec1"; const VEC_NAME2: &str = "vec2"; #[tokio::test(flavor = "multi_thread")] async fn test_multi_vec() { test_multi_vec_with_shards(1).await; test_multi_vec_with_shards(N_SHARDS).await; } #[cfg(test)] pub async fn multi_vec_collection_fixture(collection_path: &Path, shard_number: u32) -> Collection { let wal_config = WalConfig { wal_capacity_mb: 1, wal_segments_ahead: 0, }; let vector_params1 = VectorParamsBuilder::new(4, Distance::Dot).build(); let vector_params2 = VectorParamsBuilder::new(4, Distance::Dot).build(); let mut vectors_config = BTreeMap::new(); vectors_config.insert(VEC_NAME1.to_string(), vector_params1); vectors_config.insert(VEC_NAME2.to_string(), vector_params2); let collection_params = CollectionParams { vectors: VectorsConfig::Multi(vectors_config), shard_number: NonZeroU32::new(shard_number).expect("Shard number can not be zero"), ..CollectionParams::empty() }; let collection_config = CollectionConfigInternal { params: collection_params, optimizer_config: TEST_OPTIMIZERS_CONFIG.clone(), wal_config, hnsw_config: Default::default(), quantization_config: Default::default(), strict_mode_config: Default::default(), uuid: None, }; let snapshot_path = collection_path.join("snapshots"); // Default to a collection with all the shards local new_local_collection( "test".to_string(), collection_path, &snapshot_path, &collection_config, ) .await .unwrap() } async fn test_multi_vec_with_shards(shard_number: u32) { let collection_dir = Builder::new() .prefix("test_multi_vec_with_shards") .tempdir() .unwrap(); let collection = multi_vec_collection_fixture(collection_dir.path(), shard_number).await; // Upload 1000 random vectors to the collection let mut points = Vec::new(); for i in 0..1000 { let mut vectors = NamedVectors::default(); vectors.insert(VEC_NAME1.to_string(), vec![i as f32, 0.0, 0.0, 0.0].into()); vectors.insert(VEC_NAME2.to_string(), vec![0.0, i as f32, 0.0, 0.0].into()); points.push(PointStructPersisted { id: i.into(), vector: VectorStructPersisted::from(VectorStructInternal::from(vectors)), payload: Some(serde_json::from_str(r#"{"number": "John Doe"}"#).unwrap()), }); } let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints( PointInsertOperationsInternal::PointsList(points), )); collection .update_from_client_simple(insert_points, true, WriteOrdering::default()) .await .unwrap(); let query_vector = vec![6.0, 0.0, 0.0, 0.0]; let full_search_request = SearchRequestInternal { vector: NamedVector { name: VEC_NAME1.to_string(), vector: query_vector, } .into(), filter: None, limit: 10, offset: None, with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(true.into()), params: None, score_threshold: None, }; let hw_acc = HwMeasurementAcc::new(); let result = collection .search( full_search_request.into(), None, &ShardSelectorInternal::All, None, &hw_acc, ) .await .unwrap(); hw_acc.discard(); for hit in result { match hit.vector.unwrap() { VectorStructInternal::Single(_) => panic!("expected multi vector"), VectorStructInternal::MultiDense(_) => panic!("expected multi vector"), VectorStructInternal::Named(vectors) => { assert!(vectors.contains_key(VEC_NAME1)); assert!(vectors.contains_key(VEC_NAME2)); } } } let query_vector = vec![0.0, 2.0, 0.0, 0.0]; let failed_search_request = SearchRequestInternal { vector: query_vector.clone().into(), filter: None, limit: 10, offset: None, with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(true.into()), params: None, score_threshold: None, }; let hw_acc = HwMeasurementAcc::new(); let result = collection .search( failed_search_request.into(), None, &ShardSelectorInternal::All, None, &hw_acc, ) .await; hw_acc.discard(); assert!( matches!(result, Err(CollectionError::BadInput { .. })), "{result:?}" ); let full_search_request = SearchRequestInternal { vector: NamedVector { name: VEC_NAME2.to_string(), vector: query_vector, } .into(), filter: None, limit: 10, offset: None, with_payload: Some(WithPayloadInterface::Bool(true)), with_vector: Some(true.into()), params: None, score_threshold: None, }; let hw_acc = HwMeasurementAcc::new(); let result = collection .search( full_search_request.into(), None, &ShardSelectorInternal::All, None, &hw_acc, ) .await .unwrap(); hw_acc.discard(); for hit in result { match hit.vector.unwrap() { VectorStructInternal::Single(_) => panic!("expected multi vector"), VectorStructInternal::MultiDense(_) => panic!("expected multi vector"), VectorStructInternal::Named(vectors) => { assert!(vectors.contains_key(VEC_NAME1)); assert!(vectors.contains_key(VEC_NAME2)); } } } let retrieve = collection .retrieve( PointRequestInternal { ids: vec![6.into()], with_payload: Some(WithPayloadInterface::Bool(false)), with_vector: WithVector::Selector(vec![VEC_NAME1.to_string()]), }, None, &ShardSelectorInternal::All, None, ) .await .unwrap(); assert_eq!(retrieve.len(), 1); match retrieve[0].vector.as_ref().unwrap() { VectorStructInternal::Single(_) => panic!("expected multi vector"), VectorStructInternal::MultiDense(_) => panic!("expected multi vector"), VectorStructInternal::Named(vectors) => { assert!(vectors.contains_key(VEC_NAME1)); assert!(!vectors.contains_key(VEC_NAME2)); } } let hw_acc = HwMeasurementAcc::new(); let recommend_result = recommend_by( RecommendRequestInternal { positive: vec![6.into()], with_payload: Some(WithPayloadInterface::Bool(false)), with_vector: Some(WithVector::Selector(vec![VEC_NAME2.to_string()])), limit: 10, ..Default::default() }, &collection, |_name| async { unreachable!("should not be called in this test") }, None, ShardSelectorInternal::All, None, &hw_acc, ) .await; hw_acc.discard(); match recommend_result { Ok(_) => panic!("Error expected"), Err(err) => match err { CollectionError::BadRequest { .. } => {} CollectionError::BadInput { .. } => {} error => panic!("Unexpected error {error}"), }, } let hw_acc = HwMeasurementAcc::new(); let recommend_result = recommend_by( RecommendRequestInternal { positive: vec![6.into()], with_payload: Some(WithPayloadInterface::Bool(false)), with_vector: Some(WithVector::Selector(vec![VEC_NAME2.to_string()])), limit: 10, using: Some(VEC_NAME1.to_string().into()), ..Default::default() }, &collection, |_name| async { unreachable!("should not be called in this test") }, None, ShardSelectorInternal::All, None, &hw_acc, ) .await .unwrap(); hw_acc.discard(); assert_eq!(recommend_result.len(), 10); for hit in recommend_result { match hit.vector.as_ref().unwrap() { VectorStructInternal::Single(_) => panic!("expected multi vector"), VectorStructInternal::MultiDense(_) => panic!("expected multi vector"), VectorStructInternal::Named(vectors) => { assert!(!vectors.contains_key(VEC_NAME1)); assert!(vectors.contains_key(VEC_NAME2)); } } } }