use std::collections::{HashMap, HashSet};
use std::num::NonZeroU32;
use std::sync::Arc;

use api::rest::OrderByInterface;
use common::counter::hardware_accumulator::HwMeasurementAcc;
use common::cpu::CpuBudget;
use rand::{thread_rng, Rng};
use segment::data_types::vectors::NamedVectorStruct;
use segment::types::{
    Distance, ExtendedPointId, Payload, PayloadFieldSchema, PayloadSchemaType, SearchParams,
};
use serde_json::{Map, Value};
use tempfile::Builder;

use crate::collection::{Collection, RequestShardTransfer};
use crate::config::{CollectionConfigInternal, CollectionParams, WalConfig};
use crate::operations::point_ops::{
    PointInsertOperationsInternal, PointOperations, PointStructPersisted, VectorStructPersisted,
};
use crate::operations::query_enum::QueryEnum;
use crate::operations::shard_selector_internal::ShardSelectorInternal;
use crate::operations::shared_storage_config::SharedStorageConfig;
use crate::operations::types::{
    CoreSearchRequest, PointRequestInternal, ScrollRequestInternal, VectorsConfig,
};
use crate::operations::vector_params_builder::VectorParamsBuilder;
use crate::operations::{CollectionUpdateOperations, OperationWithClockTag};
use crate::optimizers_builder::OptimizersConfig;
use crate::shards::channel_service::ChannelService;
use crate::shards::collection_shard_distribution::CollectionShardDistribution;
use crate::shards::replica_set::{AbortShardTransfer, ChangePeerFromState, ReplicaState};
use crate::shards::shard::{PeerId, ShardId};

const DIM: u64 = 4;
const PEER_ID: u64 = 1;
const SHARD_COUNT: u32 = 4;
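// This point is inserted into every shard on purpose, so each shard reports it and the
// collection-level read paths must deduplicate it.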
const DUPLICATE_POINT_ID: ExtendedPointId = ExtendedPointId::NumId(100);

/// Create the collection used for deduplication tests.
async fn fixture() -> Collection {
    let wal_config = WalConfig {
        wal_capacity_mb: 1,
        wal_segments_ahead: 0,
    };
    let collection_params = CollectionParams {
        vectors: VectorsConfig::Single(VectorParamsBuilder::new(DIM, Distance::Dot).build()),
        shard_number: NonZeroU32::new(SHARD_COUNT).unwrap(),
        replication_factor: NonZeroU32::new(1).unwrap(),
        write_consistency_factor: NonZeroU32::new(1).unwrap(),
        ..CollectionParams::empty()
    };
    let config = CollectionConfigInternal {
        params: collection_params,
        optimizer_config: OptimizersConfig::fixture(),
        wal_config,
        hnsw_config: Default::default(),
        quantization_config: Default::default(),
        strict_mode_config: Default::default(),
        uuid: None,
    };
    let collection_dir = Builder::new().prefix("test_collection").tempdir().unwrap();
    let snapshots_path = Builder::new().prefix("test_snapshots").tempdir().unwrap();
    let collection_name = "test".to_string();
    let shards: HashMap<ShardId, HashSet<PeerId>> = (0..SHARD_COUNT)
        .map(|i| (i, HashSet::from([PEER_ID])))
        .collect();
    let storage_config = Arc::new(SharedStorageConfig::default());
    let collection = Collection::new(
        collection_name.clone(),
        PEER_ID,
        collection_dir.path(),
        snapshots_path.path(),
        &config,
        storage_config.clone(),
        CollectionShardDistribution { shards },
        ChannelService::default(),
        dummy_on_replica_failure(),
        dummy_request_shard_transfer(),
        dummy_abort_shard_transfer(),
        None,
        None,
        CpuBudget::default(),
        None,
    )
    .await
    .unwrap();

    // Create a payload index on `num` so that `order_by` scrolling is allowed
    collection
        .create_payload_index(
            "num".parse().unwrap(),
            PayloadFieldSchema::FieldType(PayloadSchemaType::Integer),
        )
        .await
        .expect("failed to create payload index");

    // Insert two points into every shard directly: one point whose ID matches the shard ID,
    // and point 100 (`DUPLICATE_POINT_ID`).
    // We insert into all shards directly to prevent the hash ring from spreading the points,
    // and we insert the same point into multiple shards on purpose.
    let mut rng = thread_rng();
    for (shard_id, shard) in collection.shards_holder().write().await.get_shards() {
        let op = OperationWithClockTag::from(CollectionUpdateOperations::PointOperation(
            PointOperations::UpsertPoints(PointInsertOperationsInternal::PointsList(vec![
                PointStructPersisted {
                    id: u64::from(shard_id).into(),
                    vector: VectorStructPersisted::Single(
                        (0..DIM).map(|_| rng.gen_range(0.0..1.0)).collect(),
                    ),
                    payload: Some(Payload(Map::from_iter([(
                        "num".to_string(),
                        Value::from(-(shard_id as i32)),
                    )]))),
                },
                PointStructPersisted {
                    id: DUPLICATE_POINT_ID,
                    vector: VectorStructPersisted::Single(
                        (0..DIM).map(|_| rng.gen_range(0.0..1.0)).collect(),
                    ),
                    payload: Some(Payload(Map::from_iter([(
                        "num".to_string(),
                        Value::from(100 - shard_id as i32),
                    )]))),
                },
            ])),
        ));
        shard
            .update_local(op, true)
            .await
            .expect("failed to insert points");
    }

    // Activate all shards
    for shard_id in 0..SHARD_COUNT {
        collection
            .set_shard_replica_state(shard_id, PEER_ID, ReplicaState::Active, None)
            .await
            .expect("failed to activate shard");
    }

    collection
}
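
/// Scrolling across all shards, with and without `order_by`, must return every point
/// at most once, even though `DUPLICATE_POINT_ID` lives in all shards.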
#[tokio::test]
async fn test_scroll_dedup() {
    let collection = fixture().await;

    // Scroll all points without ordering
    let result = collection
        .scroll_by(
            ScrollRequestInternal {
                offset: None,
                limit: Some(usize::MAX),
                filter: None,
                with_payload: Some(false.into()),
                with_vector: false.into(),
                order_by: None,
            },
            None,
            &ShardSelectorInternal::All,
            None,
        )
        .await
        .expect("failed to scroll");
    assert!(!result.points.is_empty(), "expected some points");
    let mut seen = HashSet::new();
    for point_id in result.points.iter().map(|point| point.id) {
        assert!(
            seen.insert(point_id),
            "got point id {point_id} more than once; points should be deduplicated",
        );
    }

    // Scroll all points with ordering
    let result = collection
        .scroll_by(
            ScrollRequestInternal {
                offset: None,
                limit: Some(usize::MAX),
                filter: None,
                with_payload: Some(false.into()),
                with_vector: false.into(),
                order_by: Some(OrderByInterface::Key("num".parse().unwrap())),
            },
            None,
            &ShardSelectorInternal::All,
            None,
        )
        .await
        .expect("failed to scroll");
    assert!(!result.points.is_empty(), "expected some points");
    let mut seen = HashSet::new();
    for point_id in result.points.iter().map(|point| point.id) {
        assert!(
            seen.insert(point_id),
            "got point id {point_id} more than once; points should be deduplicated",
        );
    }
}
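
/// Retrieving an explicit list of IDs across all shards must also deduplicate
/// `DUPLICATE_POINT_ID`, which every shard holds a copy of.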
#[tokio::test]
async fn test_retrieve_dedup() {
    let collection = fixture().await;

    let records = collection
        .retrieve(
            PointRequestInternal {
                ids: (0..u64::from(SHARD_COUNT))
                    .map(ExtendedPointId::from)
                    .chain([DUPLICATE_POINT_ID])
                    .collect(),
                with_payload: Some(false.into()),
                with_vector: false.into(),
            },
            None,
            &ShardSelectorInternal::All,
            None,
        )
        .await
        .expect("failed to retrieve");
    assert!(!records.is_empty(), "expected some records");
    let mut seen = HashSet::new();
    for point_id in records.iter().map(|record| record.id) {
        assert!(
            seen.insert(point_id),
            "got point id {point_id} more than once; points should be deduplicated",
        );
    }
}
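
/// An exact search fanned out to all shards must merge per-shard results without
/// returning `DUPLICATE_POINT_ID` more than once.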
#[tokio::test]
async fn test_search_dedup() {
    let collection = fixture().await;

    let hw_acc = HwMeasurementAcc::new();
    let points = collection
        .search(
            CoreSearchRequest {
                query: QueryEnum::Nearest(NamedVectorStruct::Default(vec![0.1, 0.2, 0.3, 0.4])),
                filter: None,
                params: Some(SearchParams {
                    exact: true,
                    ..Default::default()
                }),
                limit: 100,
                offset: 0,
                with_payload: None,
                with_vector: None,
                score_threshold: None,
            },
            None,
            &ShardSelectorInternal::All,
            None,
            &hw_acc,
        )
        .await
        .expect("failed to search");
    assert!(!points.is_empty(), "expected some points");

    // This test does not assert on hardware measurements; discard the accumulator
    hw_acc.discard();
    let mut seen = HashSet::new();
    for point_id in points.iter().map(|point| point.id) {
        assert!(
            seen.insert(point_id),
            "got point id {point_id} more than once; points should be deduplicated",
        );
    }
}
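
// No-op callbacks below satisfy the `Collection::new` signature; these tests never
// trigger replica failures or shard transfers.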
pub fn dummy_on_replica_failure() -> ChangePeerFromState {
    Arc::new(move |_peer_id, _shard_id, _from_state| {})
}

pub fn dummy_request_shard_transfer() -> RequestShardTransfer {
    Arc::new(move |_transfer| {})
}

pub fn dummy_abort_shard_transfer() -> AbortShardTransfer {
    Arc::new(|_transfer, _reason| {})
}