use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::future::try_join_all;
use itertools::Itertools as _;
use rand::distributions::WeightedIndex;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use segment::data_types::order_by::{Direction, OrderBy, OrderValue};
use segment::types::{
ExtendedPointId, Filter, ScoredPoint, WithPayload, WithPayloadInterface, WithVector,
};
use tokio::runtime::Handle;
use tokio::time::error::Elapsed;
use super::LocalShard;
use crate::collection_manager::holders::segment_holder::LockedSegment;
use crate::collection_manager::segments_searcher::SegmentsSearcher;
use crate::common::stopping_guard::StoppingGuard;
use crate::operations::types::{
CollectionError, CollectionResult, QueryScrollRequestInternal, RecordInternal, ScrollOrder,
};
impl LocalShard {
    /// Basic parallel batching; conveniently used by the universal query API.
pub(super) async fn query_scroll_batch(
&self,
batch: Arc<Vec<QueryScrollRequestInternal>>,
search_runtime_handle: &Handle,
timeout: Duration,
) -> CollectionResult<Vec<Vec<ScoredPoint>>> {
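        // Build one scroll future per request; each uses the same per-request timeout,
        // and the batch as a whole is bounded by that timeout as well.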
let scrolls = batch
.iter()
.map(|request| self.query_scroll(request, search_runtime_handle, Some(timeout)));
// execute all the scrolls concurrently
let all_scroll_results = try_join_all(scrolls);
tokio::time::timeout(timeout, all_scroll_results)
.await
.map_err(|_| {
log::debug!(
"Query scroll timeout reached: {} seconds",
timeout.as_secs()
);
CollectionError::timeout(timeout.as_secs() as usize, "Query scroll")
})?
}
    /// Scroll a single page; to be used by the universal query API only.
async fn query_scroll(
&self,
request: &QueryScrollRequestInternal,
search_runtime_handle: &Handle,
timeout: Option<Duration>,
) -> CollectionResult<Vec<ScoredPoint>> {
let QueryScrollRequestInternal {
limit,
with_vector,
filter,
scroll_order,
with_payload,
} = request;
let limit = *limit;
let offset_id = None;
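        // Dispatch on the requested scroll order; plain records are converted into
        // `ScoredPoint`s with a zero score and version.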
let point_results = match scroll_order {
ScrollOrder::ById => self
.scroll_by_id(
offset_id,
limit,
with_payload,
with_vector,
filter.as_ref(),
search_runtime_handle,
timeout,
)
.await?
.into_iter()
.map(|record| ScoredPoint {
id: record.id,
version: 0,
score: 0.0,
payload: record.payload,
vector: record.vector,
shard_key: record.shard_key,
order_value: None,
})
.collect(),
ScrollOrder::ByField(order_by) => {
let (records, values) = self
.scroll_by_field(
limit,
with_payload,
with_vector,
filter.as_ref(),
search_runtime_handle,
order_by,
timeout,
)
.await?;
records
.into_iter()
.zip(values)
.map(|(record, value)| ScoredPoint {
id: record.id,
version: 0,
score: 0.0,
payload: record.payload,
vector: record.vector,
shard_key: record.shard_key,
order_value: Some(value),
})
.collect()
}
ScrollOrder::Random => {
let records = self
.scroll_randomly(
limit,
with_payload,
with_vector,
filter.as_ref(),
search_runtime_handle,
timeout,
)
.await?;
records
.into_iter()
.map(|record| ScoredPoint {
id: record.id,
version: 0,
score: 0.0,
payload: record.payload,
vector: record.vector,
shard_key: record.shard_key,
order_value: None,
})
.collect()
}
};
Ok(point_results)
}
#[allow(clippy::too_many_arguments)]
pub async fn scroll_by_id(
&self,
offset: Option<ExtendedPointId>,
limit: usize,
with_payload_interface: &WithPayloadInterface,
with_vector: &WithVector,
filter: Option<&Filter>,
search_runtime_handle: &Handle,
timeout: Option<Duration>,
) -> CollectionResult<Vec<RecordInternal>> {
let start = Instant::now();
let timeout = timeout.unwrap_or(self.shared_storage_config.search_timeout);
let stopping_guard = StoppingGuard::new();
let segments = self.segments.clone();
let (non_appendable, appendable) = segments.read().split_segments();
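        // Closure that reads matching point IDs from a single segment on the blocking search runtime.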
let read_filtered = |segment: LockedSegment| {
let filter = filter.cloned();
let is_stopped = stopping_guard.get_is_stopped();
search_runtime_handle.spawn_blocking(move || {
segment.get().read().read_filtered(
offset,
Some(limit),
filter.as_ref(),
&is_stopped,
)
})
};
let all_reads = tokio::time::timeout(
timeout,
try_join_all(
non_appendable
.into_iter()
.chain(appendable)
.map(read_filtered),
),
)
.await
.map_err(|_: Elapsed| {
CollectionError::timeout(timeout.as_secs() as usize, "scroll_by_id")
})??;
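        // Merge the per-segment IDs: sort, dedup, and keep only the first `limit` points.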
let point_ids = all_reads
.into_iter()
.flatten()
.sorted()
.dedup()
.take(limit)
.collect_vec();
let with_payload = WithPayload::from(with_payload_interface);
// update timeout
let timeout = timeout.saturating_sub(start.elapsed());
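        // Fetch payload and vector data for the selected point IDs.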
let mut records_map = tokio::time::timeout(
timeout,
SegmentsSearcher::retrieve(
segments,
&point_ids,
&with_payload,
with_vector,
search_runtime_handle,
),
)
.await
.map_err(|_: Elapsed| CollectionError::timeout(timeout.as_secs() as usize, "retrieve"))??;
let ordered_records = point_ids
.iter()
// Use remove to avoid cloning, we take each point ID only once
.filter_map(|point_id| records_map.remove(point_id))
.collect();
Ok(ordered_records)
}
#[allow(clippy::too_many_arguments)]
pub async fn scroll_by_field(
&self,
limit: usize,
with_payload_interface: &WithPayloadInterface,
with_vector: &WithVector,
filter: Option<&Filter>,
search_runtime_handle: &Handle,
order_by: &OrderBy,
timeout: Option<Duration>,
) -> CollectionResult<(Vec<RecordInternal>, Vec<OrderValue>)> {
let start = Instant::now();
let timeout = timeout.unwrap_or(self.shared_storage_config.search_timeout);
let stopping_guard = StoppingGuard::new();
let segments = self.segments.clone();
let (non_appendable, appendable) = segments.read().split_segments();
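        // Closure that reads up to `limit` points from a single segment, ordered by the `order_by` field.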
let read_ordered_filtered = |segment: LockedSegment| {
let is_stopped = stopping_guard.get_is_stopped();
let filter = filter.cloned();
let order_by = order_by.clone();
search_runtime_handle.spawn_blocking(move || {
segment.get().read().read_ordered_filtered(
Some(limit),
filter.as_ref(),
&order_by,
&is_stopped,
)
})
};
let all_reads = tokio::time::timeout(
timeout,
try_join_all(
non_appendable
.into_iter()
.chain(appendable)
.map(read_ordered_filtered),
),
)
.await
.map_err(|_: Elapsed| {
CollectionError::timeout(timeout.as_secs() as usize, "scroll_by_field")
})??;
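        // Propagate any per-segment read error before merging.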
let all_reads = all_reads.into_iter().collect::<Result<Vec<_>, _>>()?;
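        // Merge the per-segment results, each already sorted by order value,
        // and keep the first `limit` globally ordered points.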
let (values, point_ids): (Vec<_>, Vec<_>) = all_reads
.into_iter()
.kmerge_by(|a, b| match order_by.direction() {
Direction::Asc => a <= b,
Direction::Desc => a >= b,
})
.dedup()
.take(limit)
.unzip();
let with_payload = WithPayload::from(with_payload_interface);
// update timeout
let timeout = timeout.saturating_sub(start.elapsed());
// Fetch with the requested vector and payload
let records_map = tokio::time::timeout(
timeout,
SegmentsSearcher::retrieve(
segments,
&point_ids,
&with_payload,
with_vector,
search_runtime_handle,
),
)
.await
.map_err(|_| CollectionError::timeout(timeout.as_secs() as usize, "retrieve"))??;
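        // Re-associate fetched records with the merged order; points that could not be retrieved are skipped.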
let ordered_records = point_ids
.iter()
.filter_map(|point_id| records_map.get(point_id).cloned())
.collect();
Ok((ordered_records, values))
}
async fn scroll_randomly(
&self,
limit: usize,
with_payload_interface: &WithPayloadInterface,
with_vector: &WithVector,
filter: Option<&Filter>,
search_runtime_handle: &Handle,
timeout: Option<Duration>,
) -> CollectionResult<Vec<RecordInternal>> {
let start = Instant::now();
let timeout = timeout.unwrap_or(self.shared_storage_config.search_timeout);
let stopping_guard = StoppingGuard::new();
let segments = self.segments.clone();
let (non_appendable, appendable) = segments.read().split_segments();
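        // Closure that samples up to `limit` random points from a single segment,
        // together with that segment's available point count.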
let read_filtered = |segment: LockedSegment| {
let is_stopped = stopping_guard.get_is_stopped();
let filter = filter.cloned();
search_runtime_handle.spawn_blocking(move || {
let get_segment = segment.get();
let read_segment = get_segment.read();
(
read_segment.available_point_count(),
read_segment.read_random_filtered(limit, filter.as_ref(), &is_stopped),
)
})
};
let all_reads = tokio::time::timeout(
timeout,
try_join_all(
non_appendable
.into_iter()
.chain(appendable)
.map(read_filtered),
),
)
.await
.map_err(|_: Elapsed| {
CollectionError::timeout(timeout.as_secs() as usize, "scroll_randomly")
})??;
let (availability, mut segments_reads): (Vec<_>, Vec<_>) = all_reads.into_iter().unzip();
// Shortcut if all segments are empty
if availability.iter().all(|&count| count == 0) {
return Ok(Vec::new());
}
// Select points in a weighted fashion from each segment, depending on how many points each segment has.
let distribution = WeightedIndex::new(availability).map_err(|err| {
CollectionError::service_error(format!(
"Failed to create weighted index for random scroll: {err:?}"
))
})?;
let mut rng = StdRng::from_entropy();
let mut random_points = HashSet::with_capacity(limit);
// Randomly sample points in two stages
//
        // 1. The loop below runs at most `limit` iterations; it breaks early either
        //    once we have enough points or when a sampled segment runs out of points.
        //
        // 2. If a segment was exhausted, the remaining points are filled in from the
        //    other segments. The overall complexity is still guaranteed to be
        //    O(limit).
while random_points.len() < limit {
let segment_offset = rng.sample(&distribution);
let points = segments_reads.get_mut(segment_offset).unwrap();
if let Some(point) = points.pop() {
random_points.insert(point);
} else {
                // The sampled segment ran out of points, so the weighted
                // distribution no longer reflects reality. This can only happen
                // if that segment holds fewer than `limit` points.
break;
}
}
        // If we still need more points, take them from whatever is left in the other segments.
        // This is a rare case: individual segments did not hold enough points on their own.
        // At this point the weighted distribution can be ignored, as it would not be accurate anyway.
if random_points.len() < limit {
let rest_points = segments_reads.into_iter().flatten();
for point in rest_points {
random_points.insert(point);
if random_points.len() >= limit {
break;
}
}
}
let selected_points: Vec<_> = random_points.into_iter().collect();
let with_payload = WithPayload::from(with_payload_interface);
// update timeout
let timeout = timeout.saturating_sub(start.elapsed());
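        // Fetch payload and vector data for the sampled points.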
let records_map = tokio::time::timeout(
timeout,
SegmentsSearcher::retrieve(
segments,
&selected_points,
&with_payload,
with_vector,
search_runtime_handle,
),
)
.await
.map_err(|_: Elapsed| CollectionError::timeout(timeout.as_secs() as usize, "retrieve"))??;
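        // The order of the returned records is unspecified for random scroll.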
Ok(records_map.into_values().collect())
}
}