Gouzi Mohaled
Ajout du dossier lib
84d2a97
use collection::collection::Collection;
use collection::grouping::group_by::{GroupRequest, SourceRequest};
use collection::operations::point_ops::WriteOrdering;
use collection::operations::types::{RecommendRequestInternal, UpdateStatus};
use collection::operations::CollectionUpdateOperations;
use itertools::Itertools;
use rand::distributions::Uniform;
use rand::rngs::ThreadRng;
use rand::Rng;
use segment::data_types::vectors::DenseVector;
use segment::json_path::JsonPath;
use segment::types::{Filter, Payload, WithPayloadInterface, WithVector};
use serde_json::json;
use crate::common::simple_collection_fixture;
fn rand_dense_vector(rng: &mut ThreadRng, size: usize) -> DenseVector {
rng.sample_iter(Uniform::new(0.4, 0.6)).take(size).collect()
}
mod group_by {
use api::rest::SearchRequestInternal;
use collection::grouping::GroupBy;
use collection::operations::point_ops::{
BatchPersisted, BatchVectorStructPersisted, PointInsertOperationsInternal, PointOperations,
};
use common::counter::hardware_accumulator::HwMeasurementAcc;
use super::*;
struct Resources {
request: GroupRequest,
collection: Collection,
}
async fn setup(docs: u64, chunks: u64) -> Resources {
let mut rng = rand::thread_rng();
let source = SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
});
let request = GroupRequest::with_limit_from_request(source, JsonPath::new("docId"), 3);
let collection_dir = tempfile::Builder::new()
.prefix("collection")
.tempdir()
.unwrap();
let collection = simple_collection_fixture(collection_dir.path(), 1).await;
let batch = BatchPersisted {
ids: (0..docs * chunks).map(|x| x.into()).collect_vec(),
vectors: BatchVectorStructPersisted::Single(
(0..docs * chunks)
.map(|_| rand_dense_vector(&mut rng, 4))
.collect_vec(),
),
payloads: (0..docs)
.flat_map(|x| {
(0..chunks).map(move |_| {
Some(Payload::from(
json!({ "docId": x , "other_stuff": x.to_string() + "foo" }),
))
})
})
.collect_vec()
.into(),
};
let insert_points = CollectionUpdateOperations::PointOperation(
PointOperations::UpsertPoints(PointInsertOperationsInternal::from(batch)),
);
let insert_result = collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.expect("insert failed");
assert_eq!(insert_result.status, UpdateStatus::Completed);
Resources {
request,
collection,
}
}
#[tokio::test(flavor = "multi_thread")]
async fn searching() {
let resources = setup(16, 8).await;
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
resources.request.clone(),
&resources.collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
let group_req = resources.request;
assert_eq!(result.len(), group_req.limit);
assert_eq!(result[0].hits.len(), group_req.group_size);
// is sorted?
let mut last_group_best_score = f32::MAX;
for group in result {
assert!(group.hits[0].score <= last_group_best_score);
last_group_best_score = group.hits[0].score;
let mut last_score = f32::MAX;
for hit in group.hits {
assert!(hit.score <= last_score);
last_score = hit.score;
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn recommending() {
let resources = setup(16, 8).await;
let request = GroupRequest::with_limit_from_request(
SourceRequest::Recommend(RecommendRequestInternal {
strategy: Default::default(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
positive: vec![1.into(), 2.into(), 3.into()],
negative: Vec::new(),
using: None,
lookup_from: None,
}),
JsonPath::new("docId"),
2,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
request.clone(),
&resources.collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), request.limit);
let mut last_group_best_score = f32::MAX;
for group in result {
assert_eq!(group.hits.len(), request.group_size);
// is sorted?
assert!(group.hits[0].score <= last_group_best_score);
last_group_best_score = group.hits[0].score;
let mut last_score = f32::MAX;
for hit in group.hits {
assert!(hit.score <= last_score);
last_score = hit.score;
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn with_filter() {
let resources = setup(16, 8).await;
let filter: Filter = serde_json::from_value(json!({
"must": [
{
"key": "docId",
"range": {
"gte": 1,
"lte": 2
}
}
]
}))
.unwrap();
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: Some(filter.clone()),
params: None,
limit: 4,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
}),
JsonPath::new("docId"),
3,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request,
&resources.collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 2);
}
#[tokio::test(flavor = "multi_thread")]
async fn with_payload_and_vectors() {
let resources = setup(16, 8).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: Some(WithVector::Bool(true)),
score_threshold: None,
}),
JsonPath::new("docId"),
3,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&resources.collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 4);
for group in result {
assert_eq!(group.hits.len(), group_by_request.group_size);
assert!(group.hits[0].payload.is_some());
assert!(group.hits[0].vector.is_some());
}
}
#[tokio::test(flavor = "multi_thread")]
async fn group_by_string_field() {
let Resources { collection, .. } = setup(16, 8).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: Some(WithVector::Bool(true)),
score_threshold: None,
}),
JsonPath::new("other_stuff"),
3,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 4);
for group in result {
assert_eq!(group.hits.len(), group_by_request.group_size);
}
}
#[tokio::test(flavor = "multi_thread")]
async fn zero_group_size() {
let Resources { collection, .. } = setup(16, 8).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
}),
JsonPath::new("docId"),
0,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test(flavor = "multi_thread")]
async fn zero_limit_groups() {
let Resources { collection, .. } = setup(16, 8).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 0,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
}),
JsonPath::new("docId"),
3,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test(flavor = "multi_thread")]
async fn big_limit_groups() {
let Resources { collection, .. } = setup(1000, 5).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 500,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
}),
JsonPath::new("docId"),
3,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), group_by_request.limit);
for group in result {
assert_eq!(group.hits.len(), group_by_request.group_size);
}
}
#[tokio::test(flavor = "multi_thread")]
async fn big_group_size_groups() {
let Resources { collection, .. } = setup(10, 500).await;
let group_by_request = GroupRequest::with_limit_from_request(
SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 3,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
}),
JsonPath::new("docId"),
400,
);
let hw_acc = HwMeasurementAcc::new();
let group_by = GroupBy::new(
group_by_request.clone(),
&collection,
|_| async { unreachable!() },
&hw_acc,
);
let result = group_by.execute().await;
hw_acc.discard();
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), group_by_request.limit);
for group in result {
assert_eq!(group.hits.len(), group_by_request.group_size);
}
}
}
/// Tests out the different features working together. The individual features are already tested in other places.
mod group_by_builder {
use api::rest::SearchRequestInternal;
use collection::grouping::GroupBy;
use collection::lookup::types::PseudoId;
use collection::lookup::WithLookup;
use collection::operations::point_ops::{
BatchPersisted, BatchVectorStructPersisted, PointInsertOperationsInternal, PointOperations,
};
use common::counter::hardware_accumulator::HwMeasurementAcc;
use segment::json_path::JsonPath;
use tokio::sync::RwLock;
use super::*;
const BODY_TEXT: &str = "lorem ipsum dolor sit amet";
struct Resources {
request: GroupRequest,
lookup_collection: RwLock<Collection>,
collection: Collection,
}
/// Sets up two collections: one for chunks and one for docs.
async fn setup(docs: u64, chunks_per_doc: u64) -> Resources {
let mut rng = rand::thread_rng();
let source_request = SourceRequest::Search(SearchRequestInternal {
vector: vec![0.5, 0.5, 0.5, 0.5].into(),
filter: None,
params: None,
limit: 4,
offset: None,
with_payload: None,
with_vector: None,
score_threshold: None,
});
let request =
GroupRequest::with_limit_from_request(source_request, JsonPath::new("docId"), 3);
let collection_dir = tempfile::Builder::new().prefix("chunks").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), 1).await;
// insert chunk points
{
let batch = BatchPersisted {
ids: (0..docs * chunks_per_doc).map(|x| x.into()).collect_vec(),
vectors: BatchVectorStructPersisted::Single(
(0..docs * chunks_per_doc)
.map(|_| rand_dense_vector(&mut rng, 4))
.collect_vec(),
),
payloads: (0..docs)
.flat_map(|x| {
(0..chunks_per_doc).map(move |_| Some(Payload::from(json!({ "docId": x }))))
})
.collect_vec()
.into(),
};
let insert_points = CollectionUpdateOperations::PointOperation(
PointOperations::UpsertPoints(PointInsertOperationsInternal::from(batch)),
);
let insert_result = collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.expect("insert failed");
assert_eq!(insert_result.status, UpdateStatus::Completed);
}
let lookup_dir = tempfile::Builder::new().prefix("lookup").tempdir().unwrap();
let lookup_collection = simple_collection_fixture(lookup_dir.path(), 1).await;
// insert doc points
{
let batch = BatchPersisted {
ids: (0..docs).map(|x| x.into()).collect_vec(),
vectors: BatchVectorStructPersisted::Single(
(0..docs)
.map(|_| rand_dense_vector(&mut rng, 4))
.collect_vec(),
),
payloads: (0..docs)
.map(|x| {
Some(Payload::from(
json!({ "docId": x, "body": format!("{x} {BODY_TEXT}") }),
))
})
.collect_vec()
.into(),
};
let insert_points = CollectionUpdateOperations::PointOperation(
PointOperations::UpsertPoints(PointInsertOperationsInternal::from(batch)),
);
let insert_result = lookup_collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.expect("insert failed");
assert_eq!(insert_result.status, UpdateStatus::Completed);
}
let lookup_collection = RwLock::new(lookup_collection);
Resources {
request,
lookup_collection,
collection,
}
}
#[tokio::test(flavor = "multi_thread")]
async fn only_group_by() {
let Resources {
request,
collection,
..
} = setup(16, 8).await;
let collection_by_name = |_: String| async { unreachable!() };
let hw_acc = HwMeasurementAcc::new();
let result = GroupBy::new(request.clone(), &collection, collection_by_name, &hw_acc)
.execute()
.await;
assert!(result.is_ok());
hw_acc.discard();
let result = result.unwrap();
// minimal assertion
assert_eq!(result.len(), request.limit);
for group in result {
assert_eq!(group.hits.len(), request.group_size);
assert!(group.lookup.is_none());
}
}
#[tokio::test(flavor = "multi_thread")]
async fn group_by_with_lookup() {
let Resources {
mut request,
collection,
lookup_collection,
..
} = setup(16, 8).await;
request.with_lookup = Some(WithLookup {
collection_name: "test".to_string(),
with_payload: Some(true.into()),
with_vectors: Some(true.into()),
});
let collection_by_name = |_: String| async { Some(lookup_collection.read().await) };
let hw_acc = HwMeasurementAcc::new();
let result = GroupBy::new(request.clone(), &collection, collection_by_name, &hw_acc)
.execute()
.await;
assert!(result.is_ok());
hw_acc.discard();
let result = result.unwrap();
assert_eq!(result.len(), request.limit);
for group in result {
assert_eq!(group.hits.len(), request.group_size);
let lookup = group.lookup.expect("lookup not found");
assert_eq!(PseudoId::from(group.id), PseudoId::from(lookup.id));
let payload = lookup.payload.unwrap();
let body = payload.0.get("body").unwrap().as_str().unwrap();
assert_eq!(body, &format!("{} {BODY_TEXT}", lookup.id));
}
}
}