use std::collections::HashMap; use std::sync::atomic::AtomicBool; use std::sync::Arc; use common::cpu::CpuPermit; use common::types::PointOffsetType; use itertools::Itertools; use rand::{thread_rng, Rng}; use segment::data_types::vectors::{only_default_vector, DEFAULT_VECTOR_NAME}; use segment::entry::entry_point::SegmentEntry; use segment::fixtures::payload_fixtures::{random_int_payload, random_vector}; use segment::index::hnsw_index::graph_links::GraphLinksRam; use segment::index::hnsw_index::hnsw::{HNSWIndex, HnswIndexOpenArgs}; use segment::index::hnsw_index::num_rayon_threads; use segment::index::{PayloadIndex, VectorIndex}; use segment::json_path::JsonPath; use segment::segment_constructor::build_segment; use segment::types::{ Condition, Distance, FieldCondition, Filter, HnswConfig, Indexes, Payload, PayloadSchemaType, Range, SearchParams, SegmentConfig, SeqNumberType, VectorDataConfig, VectorStorageType, }; use serde_json::json; use tempfile::Builder; #[test] fn exact_search_test() { let stopped = AtomicBool::new(false); let dim = 8; let m = 8; let num_vectors: u64 = 5_000; let ef = 32; let ef_construct = 16; let distance = Distance::Cosine; let full_scan_threshold = 16; // KB let indexing_threshold = 500; // num vectors let num_payload_values = 2; let mut rnd = thread_rng(); let dir = Builder::new().prefix("segment_dir").tempdir().unwrap(); let hnsw_dir = Builder::new().prefix("hnsw_dir").tempdir().unwrap(); let config = SegmentConfig { vector_data: HashMap::from([( DEFAULT_VECTOR_NAME.to_owned(), VectorDataConfig { size: dim, distance, storage_type: VectorStorageType::Memory, index: Indexes::Plain {}, quantization_config: None, multivector_config: None, datatype: None, }, )]), sparse_vector_data: Default::default(), payload_storage_type: Default::default(), }; let int_key = "int"; let mut segment = build_segment(dir.path(), &config, true).unwrap(); for n in 0..num_vectors { let idx = n.into(); let vector = random_vector(&mut rnd, dim); let int_payload = random_int_payload(&mut rnd, num_payload_values..=num_payload_values); let payload: Payload = json!({int_key:int_payload,}).into(); segment .upsert_point(n as SeqNumberType, idx, only_default_vector(&vector)) .unwrap(); segment .set_full_payload(n as SeqNumberType, idx, &payload) .unwrap(); } // let opnum = num_vectors + 1; let payload_index_ptr = segment.payload_index.clone(); let hnsw_config = HnswConfig { m, ef_construct, full_scan_threshold, max_indexing_threads: 2, on_disk: Some(false), payload_m: None, }; payload_index_ptr .borrow_mut() .set_indexed(&JsonPath::new(int_key), PayloadSchemaType::Integer) .unwrap(); let borrowed_payload_index = payload_index_ptr.borrow(); let blocks = borrowed_payload_index .payload_blocks(&JsonPath::new(int_key), indexing_threshold) .collect_vec(); for block in blocks.iter() { assert!( block.condition.range.is_some(), "only range conditions should be generated for this type of payload" ); } let mut coverage: HashMap = Default::default(); for block in &blocks { let px = payload_index_ptr.borrow(); let filter = Filter::new_must(Condition::Field(block.condition.clone())); let points = px.query_points(&filter); for point in points { coverage.insert(point, coverage.get(&point).unwrap_or(&0) + 1); } } let expected_blocks = num_vectors as usize / indexing_threshold * 2; eprintln!("blocks.len() = {:#?}", blocks.len()); assert!( (blocks.len() as i64 - expected_blocks as i64).abs() <= 3, "real number of payload blocks is too far from expected" ); assert_eq!( coverage.len(), num_vectors as usize, "not all points are covered by payload blocks" ); let permit_cpu_count = num_rayon_threads(hnsw_config.max_indexing_threads); let permit = Arc::new(CpuPermit::dummy(permit_cpu_count as u32)); let hnsw_index = HNSWIndex::::open(HnswIndexOpenArgs { path: hnsw_dir.path(), id_tracker: segment.id_tracker.clone(), vector_storage: segment.vector_data[DEFAULT_VECTOR_NAME] .vector_storage .clone(), quantized_vectors: segment.vector_data[DEFAULT_VECTOR_NAME] .quantized_vectors .clone(), payload_index: payload_index_ptr.clone(), hnsw_config, permit: Some(permit), stopped: &stopped, }) .unwrap(); let top = 3; let attempts = 50; for _i in 0..attempts { let query = random_vector(&mut rnd, dim).into(); let index_result = hnsw_index .search( &[&query], None, top, Some(&SearchParams { hnsw_ef: Some(ef), exact: true, ..Default::default() }), &Default::default(), ) .unwrap(); let plain_result = segment.vector_data[DEFAULT_VECTOR_NAME] .vector_index .borrow() .search(&[&query], None, top, None, &Default::default()) .unwrap(); assert_eq!( index_result, plain_result, "Exact search is not equal to plain search" ); let range_size = 40; let left_range = rnd.gen_range(0..400); let right_range = left_range + range_size; let filter = Filter::new_must(Condition::Field(FieldCondition::new_range( JsonPath::new(int_key), Range { lt: None, gt: None, gte: Some(f64::from(left_range)), lte: Some(f64::from(right_range)), }, ))); let filter_query = Some(&filter); let index_result = hnsw_index .search( &[&query], filter_query, top, Some(&SearchParams { hnsw_ef: Some(ef), exact: true, ..Default::default() }), &Default::default(), ) .unwrap(); let plain_result = segment.vector_data[DEFAULT_VECTOR_NAME] .vector_index .borrow() .search(&[&query], filter_query, top, None, &Default::default()) .unwrap(); assert_eq!( index_result, plain_result, "Exact search is not equal to plain search" ); } }