//! Sparse-vector search context over inverted-index posting lists.
use std::cmp::{max, min, Ordering}; | |
use std::sync::atomic::AtomicBool; | |
use std::sync::atomic::Ordering::Relaxed; | |
use common::counter::hardware_counter::HardwareCounterCell; | |
use common::top_k::TopK; | |
use common::types::{PointOffsetType, ScoredPointOffset}; | |
use super::posting_list_common::PostingListIter; | |
use crate::common::scores_memory_pool::PooledScoresHandle; | |
use crate::common::sparse_vector::RemappedSparseVector; | |
use crate::common::types::{DimId, DimWeight}; | |
use crate::index::inverted_index::InvertedIndex; | |
use crate::index::posting_list::PostingListIterator; | |
/// Iterator over posting lists with a reference to the corresponding query index and weight
pub struct IndexedPostingListIterator<T: PostingListIter> {
    /// Iterator over the posting list of dimension `query_index`.
    posting_list_iterator: T,
    /// Dimension id this posting list belongs to (remapped query index).
    query_index: DimId,
    /// Weight of dimension `query_index` in the query vector.
    query_weight: DimWeight,
}
/// Number of contiguous record ids scored per batch in `advance_batch`.
/// Making this larger makes the search faster but uses more (pooled) memory.
const ADVANCE_BATCH_SIZE: usize = 10_000;
/// Context for a single top-k sparse-vector search over an inverted index.
pub struct SearchContext<'a, 'b, T: PostingListIter = PostingListIterator<'a>> {
    /// One iterator per query dimension that has a posting list in the index.
    postings_iterators: Vec<IndexedPostingListIterator<T>>,
    /// The (remapped) sparse query vector.
    query: RemappedSparseVector,
    /// Number of top results to return.
    top: usize,
    /// Cancellation flag, checked during traversal.
    is_stopped: &'a AtomicBool,
    /// Priority queue holding the best results found so far.
    top_results: TopK,
    min_record_id: Option<PointOffsetType>, // min_record_id ids across all posting lists
    max_record_id: PointOffsetType,         // max_record_id ids across all posting lists
    pooled: PooledScoresHandle<'b>,         // handle to pooled scores
    /// Whether pruning via `max_next_weight` may be used (see `new`).
    use_pruning: bool,
    /// CPU usage measurement counter.
    hardware_counter: HardwareCounterCell,
}
impl<'a, 'b, T: PostingListIter> SearchContext<'a, 'b, T> { | |
pub fn new( | |
query: RemappedSparseVector, | |
top: usize, | |
inverted_index: &'a impl InvertedIndex<Iter<'a> = T>, | |
pooled: PooledScoresHandle<'b>, | |
is_stopped: &'a AtomicBool, | |
) -> SearchContext<'a, 'b, T> { | |
let mut postings_iterators = Vec::new(); | |
// track min and max record ids across all posting lists | |
let mut max_record_id = 0; | |
let mut min_record_id = u32::MAX; | |
// iterate over query indices | |
for (query_weight_offset, id) in query.indices.iter().enumerate() { | |
if let Some(mut it) = inverted_index.get(id) { | |
if let (Some(first), Some(last_id)) = (it.peek(), it.last_id()) { | |
// check if new min | |
let min_record_id_posting = first.record_id; | |
min_record_id = min(min_record_id, min_record_id_posting); | |
// check if new max | |
let max_record_id_posting = last_id; | |
max_record_id = max(max_record_id, max_record_id_posting); | |
// capture query info | |
let query_index = *id; | |
let query_weight = query.values[query_weight_offset]; | |
postings_iterators.push(IndexedPostingListIterator { | |
posting_list_iterator: it, | |
query_index, | |
query_weight, | |
}); | |
} | |
} | |
} | |
let top_results = TopK::new(top); | |
// Query vectors with negative values can NOT use the pruning mechanism which relies on the pre-computed `max_next_weight`. | |
// The max contribution per posting list that we calculate is not made to compute the max value of two negative numbers. | |
// This is a limitation of the current pruning implementation. | |
let use_pruning = T::reliable_max_next_weight() && query.values.iter().all(|v| *v >= 0.0); | |
let min_record_id = Some(min_record_id); | |
SearchContext { | |
postings_iterators, | |
query, | |
top, | |
is_stopped, | |
top_results, | |
min_record_id, | |
max_record_id, | |
pooled, | |
use_pruning, | |
hardware_counter: HardwareCounterCell::new(), | |
} | |
} | |
/// Plain search against the given ids without any pruning | |
pub fn plain_search(&mut self, ids: &[PointOffsetType]) -> Vec<ScoredPointOffset> { | |
// sort ids to fully leverage posting list iterator traversal | |
let mut sorted_ids = ids.to_vec(); | |
sorted_ids.sort_unstable(); | |
let cpu_counter = self.hardware_counter.cpu_counter_mut(); | |
for id in sorted_ids { | |
// check for cancellation | |
if self.is_stopped.load(Relaxed) { | |
break; | |
} | |
let mut indices = Vec::with_capacity(self.query.indices.len()); | |
let mut values = Vec::with_capacity(self.query.values.len()); | |
// collect indices and values for the current record id from the query's posting lists *only* | |
for posting_iterator in self.postings_iterators.iter_mut() { | |
// rely on underlying binary search as the posting lists are sorted by record id | |
match posting_iterator.posting_list_iterator.skip_to(id) { | |
None => {} // no match for posting list | |
Some(element) => { | |
// match for posting list | |
indices.push(posting_iterator.query_index); | |
values.push(element.weight); | |
} | |
} | |
} | |
// Accumulate the sum of the length of the retrieved sparse vector and the query vector length | |
// as measurement for CPU usage of plain search. | |
cpu_counter.incr_delta_mut(indices.len() + self.query.indices.len()); | |
// reconstruct sparse vector and score against query | |
let sparse_vector = RemappedSparseVector { indices, values }; | |
self.top_results.push(ScoredPointOffset { | |
score: sparse_vector.score(&self.query).unwrap_or(0.0), | |
idx: id, | |
}); | |
} | |
let top = std::mem::take(&mut self.top_results); | |
top.into_vec() | |
} | |
/// Advance posting lists iterators in a batch fashion. | |
fn advance_batch<F: Fn(PointOffsetType) -> bool>( | |
&mut self, | |
batch_start_id: PointOffsetType, | |
batch_last_id: PointOffsetType, | |
filter_condition: &F, | |
) { | |
// init batch scores | |
let batch_len = batch_last_id - batch_start_id + 1; | |
self.pooled.scores.clear(); // keep underlying allocated memory | |
self.pooled.scores.resize(batch_len as usize, 0.0); | |
for posting in self.postings_iterators.iter_mut() { | |
posting.posting_list_iterator.for_each_till_id( | |
batch_last_id, | |
self.pooled.scores.as_mut_slice(), | |
|scores, id, weight| { | |
let element_score = weight * posting.query_weight; | |
let local_id = (id - batch_start_id) as usize; | |
// SAFETY: `id` is within `batch_start_id..=batch_last_id` | |
// Thus, `local_id` is within `0..batch_len`. | |
*unsafe { scores.get_unchecked_mut(local_id) } += element_score; | |
}, | |
); | |
} | |
for (local_index, &score) in self.pooled.scores.iter().enumerate() { | |
// publish only the non-zero scores above the current min to beat | |
if score != 0.0 && score > self.top_results.threshold() { | |
let real_id = batch_start_id + local_index as PointOffsetType; | |
// do not score if filter condition is not satisfied | |
if !filter_condition(real_id) { | |
continue; | |
} | |
let score_point_offset = ScoredPointOffset { | |
score, | |
idx: real_id, | |
}; | |
self.top_results.push(score_point_offset); | |
} | |
} | |
} | |
/// Compute scores for the last posting list quickly | |
fn process_last_posting_list<F: Fn(PointOffsetType) -> bool>(&mut self, filter_condition: &F) { | |
debug_assert_eq!(self.postings_iterators.len(), 1); | |
let posting = &mut self.postings_iterators[0]; | |
posting.posting_list_iterator.for_each_till_id( | |
PointOffsetType::MAX, | |
&mut (), | |
|_, id, weight| { | |
// do not score if filter condition is not satisfied | |
if !filter_condition(id) { | |
return; | |
} | |
let score = weight * posting.query_weight; | |
self.top_results.push(ScoredPointOffset { score, idx: id }); | |
}, | |
); | |
} | |
/// Returns the next min record id from all posting list iterators | |
/// | |
/// returns None if all posting list iterators are exhausted | |
fn next_min_id(to_inspect: &mut [IndexedPostingListIterator<T>]) -> Option<PointOffsetType> { | |
let mut min_record_id = None; | |
// Iterate to find min record id at the head of the posting lists | |
for posting_iterator in to_inspect.iter_mut() { | |
if let Some(next_element) = posting_iterator.posting_list_iterator.peek() { | |
match min_record_id { | |
None => min_record_id = Some(next_element.record_id), // first record with matching id | |
Some(min_id_seen) => { | |
// update min record id if smaller | |
if next_element.record_id < min_id_seen { | |
min_record_id = Some(next_element.record_id); | |
} | |
} | |
} | |
} | |
} | |
min_record_id | |
} | |
/// Make sure the longest posting list is at the head of the posting list iterators | |
fn promote_longest_posting_lists_to_the_front(&mut self) { | |
// find index of longest posting list | |
let posting_index = self | |
.postings_iterators | |
.iter() | |
.enumerate() | |
.max_by(|(_, a), (_, b)| { | |
a.posting_list_iterator | |
.len_to_end() | |
.cmp(&b.posting_list_iterator.len_to_end()) | |
}) | |
.map(|(index, _)| index); | |
if let Some(posting_index) = posting_index { | |
// make sure it is not already at the head | |
if posting_index != 0 { | |
// swap longest posting list to the head | |
self.postings_iterators.swap(0, posting_index); | |
} | |
} | |
} | |
/// Search for the top k results that satisfy the filter condition | |
pub fn search<F: Fn(PointOffsetType) -> bool>( | |
&mut self, | |
filter_condition: &F, | |
) -> Vec<ScoredPointOffset> { | |
if self.postings_iterators.is_empty() { | |
return Vec::new(); | |
} | |
{ | |
// Measure CPU usage of indexed sparse search. | |
// Assume the complexity of the search as total volume of the posting lists | |
// that are traversed in the batched search. | |
let cpu_counter = self.hardware_counter.cpu_counter_mut(); | |
for posting in self.postings_iterators.iter() { | |
cpu_counter.incr_delta_mut(posting.posting_list_iterator.len_to_end()); | |
} | |
} | |
let mut best_min_score = f32::MIN; | |
loop { | |
// check for cancellation (atomic amortized by batch) | |
if self.is_stopped.load(Relaxed) { | |
break; | |
} | |
// prepare next iterator of batched ids | |
let Some(start_batch_id) = self.min_record_id else { | |
break; | |
}; | |
// compute batch range of contiguous ids for the next batch | |
let last_batch_id = min( | |
start_batch_id + ADVANCE_BATCH_SIZE as u32, | |
self.max_record_id, | |
); | |
// advance and score posting lists iterators | |
self.advance_batch(start_batch_id, last_batch_id, filter_condition); | |
// remove empty posting lists if necessary | |
self.postings_iterators.retain(|posting_iterator| { | |
posting_iterator.posting_list_iterator.len_to_end() != 0 | |
}); | |
// update min_record_id | |
self.min_record_id = Self::next_min_id(&mut self.postings_iterators); | |
// check if all posting lists are exhausted | |
if self.postings_iterators.is_empty() { | |
break; | |
} | |
// if only one posting list left, we can score it quickly | |
if self.postings_iterators.len() == 1 { | |
self.process_last_posting_list(filter_condition); | |
break; | |
} | |
// we potentially have enough results to prune low performing posting lists | |
if self.use_pruning && self.top_results.len() >= self.top { | |
// current min score | |
let new_min_score = self.top_results.threshold(); | |
if new_min_score == best_min_score { | |
// no improvement in lowest best score since last pruning - skip pruning | |
continue; | |
} else { | |
best_min_score = new_min_score; | |
} | |
// make sure the first posting list is the longest for pruning | |
self.promote_longest_posting_lists_to_the_front(); | |
// prune posting list that cannot possibly contribute to the top results | |
let pruned = self.prune_longest_posting_list(new_min_score); | |
if pruned { | |
// update min_record_id | |
self.min_record_id = Self::next_min_id(&mut self.postings_iterators); | |
} | |
} | |
} | |
// posting iterators exhausted, return result queue | |
let queue = std::mem::take(&mut self.top_results); | |
queue.into_vec() | |
} | |
/// Prune posting lists that cannot possibly contribute to the top results | |
/// Assumes longest posting list is at the head of the posting list iterators | |
/// Returns true if the longest posting list was pruned | |
    /// Prune posting lists that cannot possibly contribute to the top results.
    /// Assumes longest posting list is at the head of the posting list iterators.
    /// Returns true if the longest posting list was pruned.
    pub fn prune_longest_posting_list(&mut self, min_score: f32) -> bool {
        if self.postings_iterators.is_empty() {
            return false;
        }
        // peek first element of longest posting list
        let (longest_posting_iterator, rest_iterators) = self.postings_iterators.split_at_mut(1);
        let longest_posting_iterator = &mut longest_posting_iterator[0];
        if let Some(element) = longest_posting_iterator.posting_list_iterator.peek() {
            let next_min_id_in_others = Self::next_min_id(rest_iterators);
            match next_min_id_in_others {
                Some(next_min_id) => {
                    match next_min_id.cmp(&element.record_id) {
                        Ordering::Equal => {
                            // if the next min id in the other posting lists is the same as the current one,
                            // we can't prune the current element as it needs to be scored properly across posting lists
                            return false;
                        }
                        Ordering::Less => {
                            // we can't prune, as the other posting lists contain smaller ids that need to be scored first
                            return false;
                        }
                        Ordering::Greater => {
                            // next_min_id is > element.record_id so there is a chance to prune up to `next_min_id`
                            // check against the max possible score using the `max_next_weight`
                            // we can under prune as we should actually check the best score up to `next_min_id` - 1 only
                            // instead of the max possible score but it is not possible to know the best score up to `next_min_id` - 1
                            let max_weight_from_list = element.weight.max(element.max_next_weight);
                            let max_score_contribution =
                                max_weight_from_list * longest_posting_iterator.query_weight;
                            if max_score_contribution <= min_score {
                                // prune to next_min_id
                                let longest_posting_iterator =
                                    &mut self.postings_iterators[0].posting_list_iterator;
                                let position_before_pruning =
                                    longest_posting_iterator.current_index();
                                longest_posting_iterator.skip_to(next_min_id);
                                let position_after_pruning =
                                    longest_posting_iterator.current_index();
                                // check if pruning took place
                                return position_before_pruning != position_after_pruning;
                            }
                        }
                    }
                }
                None => {
                    // the current posting list is the only one left, we can potentially skip it to the end
                    // check against the max possible score using the `max_next_weight`
                    let max_weight_from_list = element.weight.max(element.max_next_weight);
                    let max_score_contribution =
                        max_weight_from_list * longest_posting_iterator.query_weight;
                    if max_score_contribution <= min_score {
                        // prune to the end!
                        let longest_posting_iterator = &mut self.postings_iterators[0];
                        longest_posting_iterator.posting_list_iterator.skip_to_end();
                        return true;
                    }
                }
            }
        }
        // no pruning took place
        false
    }
/// Return the current hardware measurement counter. | |
pub fn take_hardware_counter(&self) -> HardwareCounterCell { | |
self.hardware_counter.take() | |
} | |
} | |
/// Unit tests, generic over every `InvertedIndex` implementation.
///
/// `#[cfg(test)]` is required here: the module uses dev-dependencies
/// (`rand`, `tempfile`), which are unavailable in regular builds.
#[cfg(test)]
mod tests {
    use std::any::TypeId;
    use std::borrow::Cow;
    use std::sync::OnceLock;

    use rand::Rng;
    use tempfile::TempDir;

    use super::*;
    use crate::common::scores_memory_pool::ScoresMemoryPool;
    use crate::common::sparse_vector::SparseVector;
    use crate::common::sparse_vector_fixture::random_sparse_vector;
    use crate::common::types::QuantizedU8;
    use crate::index::inverted_index::inverted_index_compressed_immutable_ram::InvertedIndexCompressedImmutableRam;
    use crate::index::inverted_index::inverted_index_compressed_mmap::InvertedIndexCompressedMmap;
    use crate::index::inverted_index::inverted_index_immutable_ram::InvertedIndexImmutableRam;
    use crate::index::inverted_index::inverted_index_mmap::InvertedIndexMmap;
    use crate::index::inverted_index::inverted_index_ram::InvertedIndexRam;
    use crate::index::inverted_index::inverted_index_ram_builder::InvertedIndexBuilder;

    // ---- Test instantiations ----
    mod ram {}
    mod mmap {}
    mod iram {}
    mod iram_f32 {}
    mod iram_f16 {}
    mod iram_u8 {}
    mod iram_q8 {}
    mod mmap_f32 {}
    mod mmap_f16 {}
    mod mmap_u8 {}
    mod mmap_q8 {}
    // --- End of test instantiations ---

    static TEST_SCORES_POOL: OnceLock<ScoresMemoryPool> = OnceLock::new();

    /// Borrow a scores handle from a process-wide pool shared by all tests.
    fn get_pooled_scores() -> PooledScoresHandle<'static> {
        TEST_SCORES_POOL
            .get_or_init(ScoresMemoryPool::default)
            .get()
    }

    /// Match all filter condition for testing
    fn match_all(_p: PointOffsetType) -> bool {
        true
    }

    /// Helper struct to store both an index and a temporary directory
    struct TestIndex<I: InvertedIndex> {
        index: I,
        temp_dir: TempDir,
    }

    impl<I: InvertedIndex> TestIndex<I> {
        /// Build an index of type `I` from a RAM index, backed by a temp dir.
        fn from_ram(ram_index: InvertedIndexRam) -> Self {
            let temp_dir = tempfile::Builder::new()
                .prefix("test_index_dir")
                .tempdir()
                .unwrap();
            TestIndex {
                index: I::from_ram_index(Cow::Owned(ram_index), &temp_dir).unwrap(),
                temp_dir,
            }
        }
    }

    /// Round scores to allow some quantization errors
    fn round_scores<I: 'static>(mut scores: Vec<ScoredPointOffset>) -> Vec<ScoredPointOffset> {
        let errors_allowed_for = [
            TypeId::of::<InvertedIndexCompressedImmutableRam<QuantizedU8>>(),
            TypeId::of::<InvertedIndexCompressedMmap<QuantizedU8>>(),
        ];
        if errors_allowed_for.contains(&TypeId::of::<I>()) {
            let precision = 0.25;
            scores.iter_mut().for_each(|score| {
                score.score = (score.score / precision).round() * precision;
            });
            scores
        } else {
            scores
        }
    }

    /// An empty query must return no results.
    fn test_empty_query<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram(InvertedIndexRam::empty());
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector::default(), // empty query vector
            10,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(search_context.search(&match_all), Vec::new());
    }

    /// Basic search over three fully-overlapping vectors.
    fn search_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (2, 20.0), (3, 20.0)].into());
            builder.add(3, [(1, 30.0), (2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            10,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            round_scores::<I>(search_context.search(&match_all)),
            vec![
                ScoredPointOffset {
                    score: 90.0,
                    idx: 3
                },
                ScoredPointOffset {
                    score: 60.0,
                    idx: 2
                },
                ScoredPointOffset {
                    score: 30.0,
                    idx: 1
                },
            ]
        );
        // len(QueryVector)=3 * len(vector)=3 => 3*3 => 9
        let counter = search_context.take_hardware_counter();
        assert_eq!(counter.cpu_counter().get(), 9);
        counter.discard_results();
    }

    /// Search still works after upserting a new point (RAM index only).
    fn search_with_update_test<I: InvertedIndex + 'static>() {
        if TypeId::of::<I>() != TypeId::of::<InvertedIndexRam>() {
            // Only InvertedIndexRam supports upserts
            return;
        }
        let mut index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (2, 20.0), (3, 20.0)].into());
            builder.add(3, [(1, 30.0), (2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            10,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            round_scores::<I>(search_context.search(&match_all)),
            vec![
                ScoredPointOffset {
                    score: 90.0,
                    idx: 3
                },
                ScoredPointOffset {
                    score: 60.0,
                    idx: 2
                },
                ScoredPointOffset {
                    score: 30.0,
                    idx: 1
                },
            ]
        );
        search_context.take_hardware_counter().discard_results();
        drop(search_context);
        // update index with new point
        index.index.upsert(
            4,
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![40.0, 40.0, 40.0],
            },
            None,
        );
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            10,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            search_context.search(&match_all),
            vec![
                ScoredPointOffset {
                    score: 120.0,
                    idx: 4
                },
                ScoredPointOffset {
                    score: 90.0,
                    idx: 3
                },
                ScoredPointOffset {
                    score: 60.0,
                    idx: 2
                },
                ScoredPointOffset {
                    score: 30.0,
                    idx: 1
                },
            ]
        );
        search_context.take_hardware_counter().discard_results();
    }

    /// One "hot" dimension with many short postings mixed with full vectors.
    fn search_with_hot_key_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (2, 20.0), (3, 20.0)].into());
            builder.add(3, [(1, 30.0), (2, 30.0), (3, 30.0)].into());
            builder.add(4, [(1, 1.0)].into());
            builder.add(5, [(1, 2.0)].into());
            builder.add(6, [(1, 3.0)].into());
            builder.add(7, [(1, 4.0)].into());
            builder.add(8, [(1, 5.0)].into());
            builder.add(9, [(1, 6.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            3,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            round_scores::<I>(search_context.search(&match_all)),
            vec![
                ScoredPointOffset {
                    score: 90.0,
                    idx: 3
                },
                ScoredPointOffset {
                    score: 60.0,
                    idx: 2
                },
                ScoredPointOffset {
                    score: 30.0,
                    idx: 1
                },
            ]
        );
        // [ID=1] (Retrieve all 9 Vectors) => 9
        // [ID=2] (Retrieve 1-3) => 3
        // [ID=3] (Retrieve 1-3) => 3
        // 3 + 3 + 9 => 15
        assert_eq!(search_context.hardware_counter.cpu_counter().get(), 15);
        search_context.take_hardware_counter().discard_results();
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            4,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            round_scores::<I>(search_context.search(&match_all)),
            vec![
                ScoredPointOffset {
                    score: 90.0,
                    idx: 3
                },
                ScoredPointOffset {
                    score: 60.0,
                    idx: 2
                },
                ScoredPointOffset {
                    score: 30.0,
                    idx: 1
                },
                ScoredPointOffset { score: 6.0, idx: 9 },
            ]
        );
        // No difference to previous calculation because it's the same amount of score
        // calculations when increasing the "top" parameter.
        assert_eq!(search_context.hardware_counter.cpu_counter().get(), 15);
        search_context.take_hardware_counter().discard_results();
    }

    /// Pruning drains the single posting list entirely.
    fn pruning_single_to_end_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0)].into());
            builder.add(2, [(1, 20.0)].into());
            builder.add(3, [(1, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            1,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        // assuming we have gathered enough results and want to prune the longest posting list
        assert!(search_context.prune_longest_posting_list(30.0));
        // the longest posting list was pruned to the end
        assert_eq!(
            search_context.postings_iterators[0]
                .posting_list_iterator
                .len_to_end(),
            0
        );
    }

    /// Pruning drains the longest posting list when others remain.
    fn pruning_multi_to_end_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0)].into());
            builder.add(2, [(1, 20.0)].into());
            builder.add(3, [(1, 30.0)].into());
            builder.add(5, [(3, 10.0)].into());
            builder.add(6, [(2, 20.0), (3, 20.0)].into());
            builder.add(7, [(2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            1,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        // assuming we have gathered enough results and want to prune the longest posting list
        assert!(search_context.prune_longest_posting_list(30.0));
        // the longest posting list was pruned to the end
        assert_eq!(
            search_context.postings_iterators[0]
                .posting_list_iterator
                .len_to_end(),
            0
        );
    }

    /// Demonstrates that pruning is deliberately conservative (under-prunes).
    fn pruning_multi_under_prune_test<I: InvertedIndex>() {
        if !I::Iter::reliable_max_next_weight() {
            return;
        }
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0)].into());
            builder.add(2, [(1, 20.0)].into());
            builder.add(3, [(1, 20.0)].into());
            builder.add(4, [(1, 10.0)].into());
            builder.add(5, [(3, 10.0)].into());
            builder.add(6, [(1, 20.0), (2, 20.0), (3, 20.0)].into());
            builder.add(7, [(1, 40.0), (2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            1,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        // one would expect this to prune up to `6` but it does not happen it practice because we are under pruning by design
        // we should actually check the best score up to `6` - 1 only instead of the max possible score (40.0)
        assert!(!search_context.prune_longest_posting_list(30.0));
        assert!(search_context.prune_longest_posting_list(40.0));
        // the longest posting list was pruned to the end
        assert_eq!(
            search_context.postings_iterators[0]
                .posting_list_iterator
                .len_to_end(),
            2 // 6, 7
        );
    }

    /// Generates a random inverted index with `num_vectors` vectors
    fn random_inverted_index<R: Rng + ?Sized>(
        rnd_gen: &mut R,
        num_vectors: u32,
        max_sparse_dimension: usize,
    ) -> InvertedIndexRam {
        let mut inverted_index_ram = InvertedIndexRam::empty();
        for i in 1..=num_vectors {
            let SparseVector { indices, values } =
                random_sparse_vector(rnd_gen, max_sparse_dimension);
            let vector = RemappedSparseVector::new(indices, values).unwrap();
            inverted_index_ram.upsert(i, vector, None);
        }
        inverted_index_ram
    }

    /// The longest posting list is moved to the front of the iterators.
    fn promote_longest_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (3, 20.0)].into());
            builder.add(3, [(2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            3,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        assert_eq!(
            search_context.postings_iterators[0]
                .posting_list_iterator
                .len_to_end(),
            2
        );
        search_context.promote_longest_posting_lists_to_the_front();
        assert_eq!(
            search_context.postings_iterators[0]
                .posting_list_iterator
                .len_to_end(),
            3
        );
    }

    /// `plain_search` returns scores for all requested ids.
    fn plain_search_all_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (3, 20.0)].into());
            builder.add(3, [(1, 30.0), (3, 30.0)].into());
            builder.build()
        });
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 2, 3],
                values: vec![1.0, 1.0, 1.0],
            },
            3,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        let scores = search_context.plain_search(&[1, 3, 2]);
        assert_eq!(
            round_scores::<I>(scores),
            vec![
                ScoredPointOffset {
                    idx: 3,
                    score: 60.0
                },
                ScoredPointOffset {
                    idx: 2,
                    score: 40.0
                },
                ScoredPointOffset {
                    idx: 1,
                    score: 30.0
                },
            ]
        );
        // [ID=1] (Retrieve three sparse vectors (1,2,3)) + QueryLength=3 => 6
        // [ID=2] (Retrieve two sparse vectors (1,3)) + QueryLength=3 => 5
        // [ID=3] (Retrieve two sparse vectors (1,3)) + QueryLength=3 => 5
        // 6 + 5 + 5 => 16
        let hardware_counter = search_context.take_hardware_counter();
        assert_eq!(hardware_counter.cpu_counter().get(), 16);
        hardware_counter.discard_results();
    }

    /// `plain_search` handles query dimensions missing from some vectors.
    fn plain_search_gap_test<I: InvertedIndex>() {
        let index = TestIndex::<I>::from_ram({
            let mut builder = InvertedIndexBuilder::new();
            builder.add(1, [(1, 10.0), (2, 10.0), (3, 10.0)].into());
            builder.add(2, [(1, 20.0), (3, 20.0)].into());
            builder.add(3, [(2, 30.0), (3, 30.0)].into());
            builder.build()
        });
        // query vector has a gap for dimension 2
        let is_stopped = AtomicBool::new(false);
        let mut search_context = SearchContext::new(
            RemappedSparseVector {
                indices: vec![1, 3],
                values: vec![1.0, 1.0],
            },
            3,
            &index.index,
            get_pooled_scores(),
            &is_stopped,
        );
        let scores = search_context.plain_search(&[1, 2, 3]);
        assert_eq!(
            round_scores::<I>(scores),
            vec![
                ScoredPointOffset {
                    idx: 2,
                    score: 40.0
                },
                ScoredPointOffset {
                    idx: 3,
                    score: 30.0 // the dimension 2 did not contribute to the score
                },
                ScoredPointOffset {
                    idx: 1,
                    score: 20.0 // the dimension 2 did not contribute to the score
                },
            ]
        );
        // [ID=1] (Retrieve two sparse vectors (1,2)) + QueryLength=2 => 4
        // [ID=2] (Retrieve two sparse vectors (1,3)) + QueryLength=2 => 4
        // [ID=3] (Retrieve one sparse vector (3)) + QueryLength=2 => 3
        // 4 + 4 + 3 => 11
        let hardware_counter = search_context.take_hardware_counter();
        assert_eq!(hardware_counter.cpu_counter().get(), 11);
        hardware_counter.discard_results();
    }
}