// Spaces:
// Build error
// Build error
use std::collections::{BTreeSet, HashMap}; | |
use std::sync::atomic::AtomicBool; | |
use std::sync::Arc; | |
use atomic_refcell::AtomicRefCell; | |
use common::cpu::CpuPermit; | |
use common::types::ScoredPointOffset; | |
use itertools::Itertools; | |
use rand::prelude::StdRng; | |
use rand::{Rng, SeedableRng}; | |
use rstest::rstest; | |
use segment::data_types::vectors::{ | |
only_default_vector, DenseVector, QueryVector, DEFAULT_VECTOR_NAME, | |
}; | |
use segment::entry::entry_point::SegmentEntry; | |
use segment::fixtures::payload_fixtures::{random_dense_byte_vector, random_int_payload}; | |
use segment::index::hnsw_index::graph_links::GraphLinksRam; | |
use segment::index::hnsw_index::hnsw::{HNSWIndex, HnswIndexOpenArgs}; | |
use segment::index::hnsw_index::num_rayon_threads; | |
use segment::index::{PayloadIndex, VectorIndex}; | |
use segment::segment_constructor::build_segment; | |
use segment::types::{ | |
BinaryQuantizationConfig, CompressionRatio, Condition, Distance, FieldCondition, Filter, | |
HnswConfig, Indexes, Payload, PayloadSchemaType, ProductQuantizationConfig, | |
QuantizationSearchParams, Range, ScalarQuantizationConfig, SearchParams, SegmentConfig, | |
SeqNumberType, VectorDataConfig, VectorStorageDatatype, VectorStorageType, | |
}; | |
use segment::vector_storage::quantized::quantized_vectors::QuantizedVectors; | |
use segment::vector_storage::query::{ContextPair, DiscoveryQuery, RecoQuery}; | |
use segment::vector_storage::VectorStorageEnum; | |
use serde_json::json; | |
use tempfile::Builder; | |
/// Exclusive upper bound used by the random query generators when picking how many
/// example pairs (discovery) or positive/negative examples (recommendation) to draw.
///
/// It is passed as `gen_range(1..MAX_EXAMPLE_PAIRS)`, so a generated query carries between
/// `1` and `MAX_EXAMPLE_PAIRS - 1` examples.
const MAX_EXAMPLE_PAIRS: usize = 4;
/// Kind of query generated for every search attempt (dispatched by `random_query`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QueryVariant {
    /// Plain nearest-neighbour search with a single random vector.
    Nearest,
    /// Recommendation query (`RecoQuery`) built from random positive and negative examples.
    RecommendBestScore,
    /// Discovery query (`DiscoveryQuery`): a random target plus random context pairs.
    Discovery,
}
/// Quantization scheme attached to the segment's vectors before the HNSW index is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QuantizationVariant {
    /// Scalar quantization (`ScalarQuantizationConfig`).
    Scalar,
    /// Product quantization (`ProductQuantizationConfig`).
    PQ,
    /// Binary quantization (`BinaryQuantizationConfig`).
    Binary,
}
fn random_vector<R>(rnd_gen: &mut R, dim: usize, data_type: VectorStorageDatatype) -> DenseVector | |
where | |
R: Rng + ?Sized, | |
{ | |
match data_type { | |
VectorStorageDatatype::Float32 => unreachable!(), | |
VectorStorageDatatype::Float16 => { | |
let mut vector = segment::fixtures::payload_fixtures::random_vector(rnd_gen, dim); | |
vector.iter_mut().for_each(|x| *x -= 0.5); | |
vector | |
} | |
VectorStorageDatatype::Uint8 => random_dense_byte_vector(rnd_gen, dim), | |
} | |
} | |
fn random_discovery_query<R: Rng + ?Sized>( | |
rnd: &mut R, | |
dim: usize, | |
data_type: VectorStorageDatatype, | |
) -> QueryVector { | |
let num_pairs: usize = rnd.gen_range(1..MAX_EXAMPLE_PAIRS); | |
let target = random_vector(rnd, dim, data_type).into(); | |
let pairs = (0..num_pairs) | |
.map(|_| { | |
let positive = random_vector(rnd, dim, data_type).into(); | |
let negative = random_vector(rnd, dim, data_type).into(); | |
ContextPair { positive, negative } | |
}) | |
.collect_vec(); | |
DiscoveryQuery::new(target, pairs).into() | |
} | |
fn random_reco_query<R: Rng + ?Sized>( | |
rnd: &mut R, | |
dim: usize, | |
data_type: VectorStorageDatatype, | |
) -> QueryVector { | |
let num_examples: usize = rnd.gen_range(1..MAX_EXAMPLE_PAIRS); | |
let positive = (0..num_examples) | |
.map(|_| random_vector(rnd, dim, data_type).into()) | |
.collect_vec(); | |
let negative = (0..num_examples) | |
.map(|_| random_vector(rnd, dim, data_type).into()) | |
.collect_vec(); | |
RecoQuery::new(positive, negative).into() | |
} | |
fn random_query<R: Rng + ?Sized>( | |
variant: &QueryVariant, | |
rnd: &mut R, | |
dim: usize, | |
data_type: VectorStorageDatatype, | |
) -> QueryVector { | |
match variant { | |
QueryVariant::Nearest => random_vector(rnd, dim, data_type).into(), | |
QueryVariant::Discovery => random_discovery_query(rnd, dim, data_type), | |
QueryVariant::RecommendBestScore => random_reco_query(rnd, dim, data_type), | |
} | |
} | |
fn sames_count(a: &[Vec<ScoredPointOffset>], b: &[Vec<ScoredPointOffset>]) -> usize { | |
a[0].iter() | |
.map(|x| x.idx) | |
.collect::<BTreeSet<_>>() | |
.intersection(&b[0].iter().map(|x| x.idx).collect()) | |
.count() | |
} | |
fn test_byte_storage_binary_quantization_hnsw( | |
query_variant: QueryVariant, | |
storage_data_type: VectorStorageDatatype, | |
quantization_variant: QuantizationVariant, | |
distance: Distance, | |
usize, | dim:|
usize, | ef:|
f64, // out of 100 | min_acc:|
) { | |
use segment::json_path::JsonPath; | |
let stopped = AtomicBool::new(false); | |
let m = 8; | |
let num_vectors: u64 = 5_000; | |
let ef_construct = 16; | |
let full_scan_threshold = 16; // KB | |
let num_payload_values = 2; | |
let mut rnd = StdRng::seed_from_u64(42); | |
let dir_byte = Builder::new().prefix("segment_dir_byte").tempdir().unwrap(); | |
let quantized_data_path = dir_byte.path(); | |
let hnsw_dir_byte = Builder::new().prefix("hnsw_dir_byte").tempdir().unwrap(); | |
let config_byte = SegmentConfig { | |
vector_data: HashMap::from([( | |
DEFAULT_VECTOR_NAME.to_owned(), | |
VectorDataConfig { | |
size: dim, | |
distance, | |
storage_type: VectorStorageType::Memory, | |
index: Indexes::Plain {}, | |
quantization_config: None, | |
multivector_config: None, | |
datatype: Some(storage_data_type), | |
}, | |
)]), | |
sparse_vector_data: Default::default(), | |
payload_storage_type: Default::default(), | |
}; | |
let int_key = "int"; | |
let mut segment_byte = build_segment(dir_byte.path(), &config_byte, true).unwrap(); | |
// check that `segment_byte` uses byte or half storage | |
{ | |
let borrowed_storage = segment_byte.vector_data[DEFAULT_VECTOR_NAME] | |
.vector_storage | |
.borrow(); | |
let raw_storage: &VectorStorageEnum = &borrowed_storage; | |
assert!( | |
matches!(raw_storage, &VectorStorageEnum::DenseSimpleByte(_)) | |
| matches!(raw_storage, &VectorStorageEnum::DenseSimpleHalf(_)) | |
); | |
} | |
for n in 0..num_vectors { | |
let idx = n.into(); | |
let vector = random_vector(&mut rnd, dim, storage_data_type); | |
let int_payload = random_int_payload(&mut rnd, num_payload_values..=num_payload_values); | |
let payload: Payload = json!({int_key:int_payload,}).into(); | |
segment_byte | |
.upsert_point(n as SeqNumberType, idx, only_default_vector(&vector)) | |
.unwrap(); | |
segment_byte | |
.set_full_payload(n as SeqNumberType, idx, &payload) | |
.unwrap(); | |
} | |
segment_byte | |
.payload_index | |
.borrow_mut() | |
.set_indexed(&JsonPath::new(int_key), PayloadSchemaType::Integer) | |
.unwrap(); | |
let quantization_config = match quantization_variant { | |
QuantizationVariant::Scalar => ScalarQuantizationConfig { | |
r#type: Default::default(), | |
quantile: None, | |
always_ram: None, | |
} | |
.into(), | |
QuantizationVariant::PQ => ProductQuantizationConfig { | |
compression: CompressionRatio::X8, | |
always_ram: None, | |
} | |
.into(), | |
QuantizationVariant::Binary => BinaryQuantizationConfig { always_ram: None }.into(), | |
}; | |
segment_byte | |
.vector_data | |
.values_mut() | |
.for_each(|vector_storage| { | |
let quantized_vectors = QuantizedVectors::create( | |
&vector_storage.vector_storage.borrow(), | |
&quantization_config, | |
quantized_data_path, | |
4, | |
&stopped, | |
) | |
.unwrap(); | |
vector_storage.quantized_vectors = | |
Arc::new(AtomicRefCell::new(Some(quantized_vectors))); | |
}); | |
let hnsw_config = HnswConfig { | |
m, | |
ef_construct, | |
full_scan_threshold, | |
max_indexing_threads: 2, | |
on_disk: Some(false), | |
payload_m: None, | |
}; | |
let permit_cpu_count = num_rayon_threads(hnsw_config.max_indexing_threads); | |
let permit = Arc::new(CpuPermit::dummy(permit_cpu_count as u32)); | |
let hnsw_index_byte = HNSWIndex::<GraphLinksRam>::open(HnswIndexOpenArgs { | |
path: hnsw_dir_byte.path(), | |
id_tracker: segment_byte.id_tracker.clone(), | |
vector_storage: segment_byte.vector_data[DEFAULT_VECTOR_NAME] | |
.vector_storage | |
.clone(), | |
quantized_vectors: segment_byte.vector_data[DEFAULT_VECTOR_NAME] | |
.quantized_vectors | |
.clone(), | |
payload_index: segment_byte.payload_index.clone(), | |
hnsw_config, | |
permit: Some(permit), | |
stopped: &stopped, | |
}) | |
.unwrap(); | |
let top = 5; | |
let mut sames = 0; | |
let attempts = 100; | |
for _ in 0..attempts { | |
let query = random_query(&query_variant, &mut rnd, dim, storage_data_type); | |
let range_size = 40; | |
let left_range = rnd.gen_range(0..400); | |
let right_range = left_range + range_size; | |
let filter = Filter::new_must(Condition::Field(FieldCondition::new_range( | |
JsonPath::new(int_key), | |
Range { | |
lt: None, | |
gt: None, | |
gte: Some(f64::from(left_range)), | |
lte: Some(f64::from(right_range)), | |
}, | |
))); | |
let filter_query = Some(&filter); | |
let index_result_byte = hnsw_index_byte | |
.search( | |
&[&query], | |
filter_query, | |
top, | |
Some(&SearchParams { | |
hnsw_ef: Some(ef), | |
quantization: Some(QuantizationSearchParams { | |
oversampling: Some(2.0), | |
..Default::default() | |
}), | |
..Default::default() | |
}), | |
&Default::default(), | |
) | |
.unwrap(); | |
let plain_result_byte = hnsw_index_byte | |
.search( | |
&[&query], | |
filter_query, | |
top, | |
Some(&SearchParams { | |
hnsw_ef: Some(ef), | |
quantization: Some(QuantizationSearchParams { | |
ignore: true, | |
..Default::default() | |
}), | |
exact: true, | |
..Default::default() | |
}), | |
&Default::default(), | |
) | |
.unwrap(); | |
sames += sames_count(&plain_result_byte, &index_result_byte); | |
} | |
let acc = 100.0 * sames as f64 / (attempts * top) as f64; | |
println!("sames = {sames}, attempts = {attempts}, top = {top}, acc = {acc}"); | |
assert!(acc > min_acc); | |
} | |