// colibri.qdrant / lib/segment/tests/integration/multivector_hnsw_test.rs
// Gouzi Mohaled
// Ajout du dossier lib
// 84d2a97
use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use atomic_refcell::AtomicRefCell;
use common::cpu::CpuPermit;
use rand::prelude::StdRng;
use rand::SeedableRng;
use segment::common::rocksdb_wrapper::{open_db, DB_VECTOR_CF};
use segment::data_types::vectors::{
only_default_vector, MultiDenseVectorInternal, QueryVector, TypedMultiDenseVectorRef,
VectorElementType, VectorRef, DEFAULT_VECTOR_NAME,
};
use segment::entry::entry_point::SegmentEntry;
use segment::fixtures::index_fixtures::random_vector;
use segment::fixtures::payload_fixtures::random_int_payload;
use segment::index::hnsw_index::graph_links::GraphLinksRam;
use segment::index::hnsw_index::hnsw::{HNSWIndex, HnswIndexOpenArgs};
use segment::index::VectorIndex;
use segment::json_path::JsonPath;
use segment::segment_constructor::build_segment;
use segment::spaces::metric::Metric;
use segment::spaces::simple::{CosineMetric, DotProductMetric, EuclidMetric, ManhattanMetric};
use segment::types::{
Condition, Distance, FieldCondition, Filter, HnswConfig, Indexes, MultiVectorConfig, Payload,
PayloadSchemaType, SegmentConfig, SeqNumberType, VectorDataConfig, VectorStorageType,
};
use segment::vector_storage::multi_dense::simple_multi_dense_vector_storage::open_simple_multi_dense_vector_storage;
use segment::vector_storage::VectorStorage;
use serde_json::json;
use tempfile::Builder;
/// Checks that an HNSW index over multi-dense vectors, each holding exactly one sub-vector,
/// returns the same filtered search results as an HNSW index over the equivalent dense vectors.
///
/// Every point is written twice:
/// - as a dense vector, with a random integer payload, into a segment that has an integer
///   field index;
/// - as a single-sub-vector multi-dense vector into a standalone storage, under the same
///   internal id.
///
/// Both indexes share the segment's id tracker and payload index, so every filtered query must
/// produce identical results on both.
#[test]
fn test_single_multi_and_dense_hnsw_equivalency() {
    let num_vectors: u64 = 1_000;
    let distance = Distance::Cosine;
    let num_payload_values = 2;
    let dim = 8;

    let mut rnd = StdRng::seed_from_u64(42);

    // A `TempDir` deletes its directory when dropped. Each one therefore gets its own binding
    // that lives until the end of the test.
    let segment_dir = Builder::new().prefix("segment_dir").tempdir().unwrap();

    let config = SegmentConfig {
        vector_data: HashMap::from([(
            DEFAULT_VECTOR_NAME.to_owned(),
            VectorDataConfig {
                size: dim,
                distance,
                storage_type: VectorStorageType::Memory,
                index: Indexes::Plain {},
                quantization_config: None,
                multivector_config: None,
                datatype: None,
            },
        )]),
        sparse_vector_data: Default::default(),
        payload_storage_type: Default::default(),
    };

    let int_key = "int";

    let mut segment = build_segment(segment_dir.path(), &config, true).unwrap();
    segment
        .create_field_index(
            0,
            &JsonPath::new(int_key),
            Some(&PayloadSchemaType::Integer.into()),
        )
        .unwrap();

    let storage_dir = Builder::new().prefix("storage_dir").tempdir().unwrap();
    let db = open_db(storage_dir.path(), &[DB_VECTOR_CF]).unwrap();
    let mut multi_storage = open_simple_multi_dense_vector_storage(
        db,
        DB_VECTOR_CF,
        dim,
        distance,
        MultiVectorConfig::default(),
        &AtomicBool::new(false),
    )
    .unwrap();

    for n in 0..num_vectors {
        let idx = n.into();
        let vector = random_vector(&mut rnd, dim);

        // The standalone storage is written directly, so the distance's `preprocess` step is
        // applied here by hand.
        // NOTE(review): this assumes the segment applies the same step to `vector` on upsert —
        // confirm.
        let preprocessed_vector = match distance {
            Distance::Cosine => {
                <CosineMetric as Metric<VectorElementType>>::preprocess(vector.clone())
            }
            Distance::Euclid => {
                <EuclidMetric as Metric<VectorElementType>>::preprocess(vector.clone())
            }
            Distance::Dot => {
                <DotProductMetric as Metric<VectorElementType>>::preprocess(vector.clone())
            }
            Distance::Manhattan => {
                <ManhattanMetric as Metric<VectorElementType>>::preprocess(vector.clone())
            }
        };
        // A flattened buffer whose sub-vector dimension equals its length: one sub-vector.
        let vector_multi = MultiDenseVectorInternal::new(preprocessed_vector, vector.len());

        let int_payload = random_int_payload(&mut rnd, num_payload_values..=num_payload_values);
        let payload: Payload = json!({ int_key: int_payload }).into();

        segment
            .upsert_point(n as SeqNumberType, idx, only_default_vector(&vector))
            .unwrap();
        segment
            .set_full_payload(n as SeqNumberType, idx, &payload)
            .unwrap();

        // Store the multi-vector under the internal id the segment assigned. Both indexes then
        // resolve the point identically through the shared id tracker.
        let internal_id = segment.id_tracker.borrow().internal_id(idx).unwrap();
        multi_storage
            .insert_vector(
                internal_id,
                VectorRef::MultiDense(TypedMultiDenseVectorRef::from(&vector_multi)),
            )
            .unwrap();
    }

    let stopped = AtomicBool::new(false);

    let m = 8;
    let ef_construct = 100;
    let full_scan_threshold = 10000;

    let hnsw_config = HnswConfig {
        m,
        ef_construct,
        full_scan_threshold,
        max_indexing_threads: 2,
        on_disk: Some(false),
        payload_m: None,
    };

    // A single-CPU permit forces single-threaded index building. The intent is that the dense
    // and multi-vector builds behave identically.
    let permit = Arc::new(CpuPermit::dummy(1));

    let vector_storage = &segment.vector_data[DEFAULT_VECTOR_NAME].vector_storage;
    let quantized_vectors = &segment.vector_data[DEFAULT_VECTOR_NAME].quantized_vectors;

    // Each index gets its own directory. With a shared path, the second `open` could pick up
    // whatever the first one persisted there, instead of building a graph from the multi-dense
    // storage. That would defeat the purpose of the comparison.
    let dense_hnsw_dir = Builder::new().prefix("hnsw_dir_dense").tempdir().unwrap();
    let multi_hnsw_dir = Builder::new().prefix("hnsw_dir_multi").tempdir().unwrap();

    let hnsw_index_dense = HNSWIndex::<GraphLinksRam>::open(HnswIndexOpenArgs {
        path: dense_hnsw_dir.path(),
        id_tracker: segment.id_tracker.clone(),
        vector_storage: vector_storage.clone(),
        quantized_vectors: quantized_vectors.clone(),
        payload_index: segment.payload_index.clone(),
        hnsw_config: hnsw_config.clone(),
        permit: Some(permit.clone()),
        stopped: &stopped,
    })
    .unwrap();

    let multi_storage = Arc::new(AtomicRefCell::new(multi_storage));
    let hnsw_index_multi = HNSWIndex::<GraphLinksRam>::open(HnswIndexOpenArgs {
        path: multi_hnsw_dir.path(),
        id_tracker: segment.id_tracker.clone(),
        vector_storage: multi_storage,
        quantized_vectors: quantized_vectors.clone(),
        payload_index: segment.payload_index.clone(),
        hnsw_config,
        permit: Some(permit),
        stopped: &stopped,
    })
    .unwrap();

    for query_idx in 0..10 {
        // Named `query_dense` so it does not shadow the `random_vector` fixture function.
        let query_dense = random_vector(&mut rnd, dim);
        let query_vector = query_dense.clone().into();
        // The same query, expressed as a multi-vector with a single sub-vector.
        let query_vector_multi = QueryVector::Nearest(vec![query_dense].try_into().unwrap());

        let payload_value = random_int_payload(&mut rnd, 1..=1).pop().unwrap();
        let filter = Filter::new_must(Condition::Field(FieldCondition::new_match(
            JsonPath::new(int_key),
            payload_value.into(),
        )));

        let search_res_dense = hnsw_index_dense
            .search(
                &[&query_vector],
                Some(&filter),
                10,
                None,
                &Default::default(),
            )
            .unwrap();

        let search_res_multi = hnsw_index_multi
            .search(
                &[&query_vector_multi],
                Some(&filter),
                10,
                None,
                &Default::default(),
            )
            .unwrap();

        assert_eq!(
            search_res_dense, search_res_multi,
            "dense and single-sub-vector multi-dense HNSW results differ for query #{query_idx}",
        );
    }
}