Gouzi Mohaled
Ajout du dossier lib
84d2a97
use std::collections::{HashMap, HashSet};
use std::fs::File;
use api::rest::{OrderByInterface, SearchRequestInternal};
use collection::operations::payload_ops::{PayloadOps, SetPayloadOp};
use collection::operations::point_ops::{
BatchPersisted, BatchVectorStructPersisted, PointInsertOperationsInternal, PointOperations,
PointStructPersisted, VectorStructPersisted, WriteOrdering,
};
use collection::operations::shard_selector_internal::ShardSelectorInternal;
use collection::operations::types::{
CountRequestInternal, PointRequestInternal, RecommendRequestInternal, ScrollRequestInternal,
UpdateStatus,
};
use collection::operations::CollectionUpdateOperations;
use collection::recommendations::recommend_by;
use collection::shards::replica_set::{ReplicaSetState, ReplicaState};
use common::counter::hardware_accumulator::HwMeasurementAcc;
use itertools::Itertools;
use segment::data_types::order_by::{Direction, OrderBy};
use segment::data_types::vectors::VectorStructInternal;
use segment::types::{
Condition, ExtendedPointId, FieldCondition, Filter, HasIdCondition, Payload,
PayloadFieldSchema, PayloadSchemaType, PointIdType, WithPayloadInterface,
};
use serde_json::Map;
use tempfile::Builder;
use crate::common::{load_local_collection, simple_collection_fixture, N_SHARDS};
#[tokio::test(flavor = "multi_thread")]
async fn test_collection_updater() {
test_collection_updater_with_shards(1).await;
test_collection_updater_with_shards(N_SHARDS).await;
}
async fn test_collection_updater_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![1.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 1.0, 0.0],
vec![1.0, 1.0, 1.0, 1.0],
vec![1.0, 1.0, 0.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
let insert_result = collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await;
match insert_result {
Ok(res) => {
assert_eq!(res.status, UpdateStatus::Completed)
}
Err(err) => panic!("operation failed: {err:?}"),
}
let search_request = SearchRequestInternal {
vector: vec![1.0, 1.0, 1.0, 1.0].into(),
with_payload: None,
with_vector: None,
filter: None,
params: None,
limit: 3,
offset: None,
score_threshold: None,
};
let hw_acc = HwMeasurementAcc::new();
let search_res = collection
.search(
search_request.into(),
None,
&ShardSelectorInternal::All,
None,
&hw_acc,
)
.await;
hw_acc.discard();
match search_res {
Ok(res) => {
assert_eq!(res.len(), 3);
assert_eq!(res[0].id, 2.into());
assert!(res[0].payload.is_none());
}
Err(err) => panic!("search failed: {err:?}"),
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_collection_search_with_payload_and_vector() {
test_collection_search_with_payload_and_vector_with_shards(1).await;
test_collection_search_with_payload_and_vector_with_shards(N_SHARDS).await;
}
async fn test_collection_search_with_payload_and_vector_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0.into(), 1.into()],
vectors: BatchVectorStructPersisted::Single(vec![
vec![1.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 1.0, 0.0],
]),
payloads: serde_json::from_str(
r#"[{ "k": { "type": "keyword", "value": "v1" } }, { "k": "v2" , "v": "v3"}]"#,
)
.unwrap(),
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
let insert_result = collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await;
match insert_result {
Ok(res) => {
assert_eq!(res.status, UpdateStatus::Completed)
}
Err(err) => panic!("operation failed: {err:?}"),
}
let search_request = SearchRequestInternal {
vector: vec![1.0, 0.0, 1.0, 1.0].into(),
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: Some(true.into()),
filter: None,
params: None,
limit: 3,
offset: None,
score_threshold: None,
};
let hw_acc = HwMeasurementAcc::new();
let search_res = collection
.search(
search_request.into(),
None,
&ShardSelectorInternal::All,
None,
&hw_acc,
)
.await;
hw_acc.discard();
match search_res {
Ok(res) => {
assert_eq!(res.len(), 2);
assert_eq!(res[0].id, 0.into());
assert_eq!(res[0].payload.as_ref().unwrap().len(), 1);
let vec = vec![1.0, 0.0, 1.0, 1.0];
match &res[0].vector {
Some(VectorStructInternal::Single(v)) => assert_eq!(v.clone(), vec),
_ => panic!("vector is not returned"),
}
}
Err(err) => panic!("search failed: {err:?}"),
}
let count_request = CountRequestInternal {
filter: Some(Filter::new_must(Condition::Field(
FieldCondition::new_match(
"k".parse().unwrap(),
serde_json::from_str(r#"{ "value": "v2" }"#).unwrap(),
),
))),
exact: true,
};
let hw_acc = HwMeasurementAcc::new();
let count_res = collection
.count(
count_request,
None,
&ShardSelectorInternal::All,
None,
&hw_acc,
)
.await
.unwrap();
assert_eq!(count_res.count, 1);
hw_acc.discard();
}
// FIXME: does not work
#[tokio::test(flavor = "multi_thread")]
async fn test_collection_loading() {
test_collection_loading_with_shards(1).await;
test_collection_loading_with_shards(N_SHARDS).await;
}
async fn test_collection_loading_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
{
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![1.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 1.0, 0.0],
vec![1.0, 1.0, 1.0, 1.0],
vec![1.0, 1.0, 0.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(
PointOperations::UpsertPoints(PointInsertOperationsInternal::from(batch)),
);
collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.unwrap();
let payload: Payload = serde_json::from_str(r#"{"color":"red"}"#).unwrap();
let assign_payload =
CollectionUpdateOperations::PayloadOperation(PayloadOps::SetPayload(SetPayloadOp {
payload,
points: Some(vec![2.into(), 3.into()]),
filter: None,
key: None,
}));
collection
.update_from_client_simple(assign_payload, true, WriteOrdering::default())
.await
.unwrap();
}
let collection_path = collection_dir.path();
let loaded_collection = load_local_collection(
"test".to_string(),
collection_path,
&collection_path.join("snapshots"),
)
.await;
let request = PointRequestInternal {
ids: vec![1.into(), 2.into()],
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: true.into(),
};
let retrieved = loaded_collection
.retrieve(request, None, &ShardSelectorInternal::All, None)
.await
.unwrap();
assert_eq!(retrieved.len(), 2);
for record in retrieved {
if record.id == 2.into() {
let non_empty_payload = record.payload.unwrap();
assert_eq!(non_empty_payload.len(), 1)
}
}
println!("Function end");
}
#[test]
fn test_deserialization() {
let batch = BatchPersisted {
ids: vec![0.into(), 1.into()],
vectors: BatchVectorStructPersisted::Single(vec![
vec![1.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 1.0, 0.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
let json_str = serde_json::to_string_pretty(&insert_points).unwrap();
let _read_obj: CollectionUpdateOperations = serde_json::from_str(&json_str).unwrap();
let crob_bytes = rmp_serde::to_vec(&insert_points).unwrap();
let _read_obj2: CollectionUpdateOperations = rmp_serde::from_slice(&crob_bytes).unwrap();
}
#[test]
fn test_deserialization2() {
let points = vec![
PointStructPersisted {
id: 0.into(),
vector: VectorStructPersisted::from(vec![1.0, 0.0, 1.0, 1.0]),
payload: None,
},
PointStructPersisted {
id: 1.into(),
vector: VectorStructPersisted::from(vec![1.0, 0.0, 1.0, 0.0]),
payload: None,
},
];
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(points),
));
let json_str = serde_json::to_string_pretty(&insert_points).unwrap();
let _read_obj: CollectionUpdateOperations = serde_json::from_str(&json_str).unwrap();
let raw_bytes = rmp_serde::to_vec(&insert_points).unwrap();
let _read_obj2: CollectionUpdateOperations = rmp_serde::from_slice(&raw_bytes).unwrap();
}
// Request to find points sent to all shards but they might not have a particular id, so they will return an error
#[tokio::test(flavor = "multi_thread")]
async fn test_recommendation_api() {
test_recommendation_api_with_shards(1).await;
test_recommendation_api_with_shards(N_SHARDS).await;
}
async fn test_recommendation_api_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4, 5, 6, 7, 8]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![0.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
vec![0.0, 0.0, 0.0, 1.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
let hw_acc = HwMeasurementAcc::new();
collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.unwrap();
let result = recommend_by(
RecommendRequestInternal {
positive: vec![0.into()],
negative: vec![8.into()],
limit: 5,
..Default::default()
},
&collection,
|_name| async { unreachable!("Should not be called in this test") },
None,
ShardSelectorInternal::All,
None,
&hw_acc,
)
.await
.unwrap();
assert!(!result.is_empty());
let top1 = &result[0];
hw_acc.discard();
assert!(top1.id == 5.into() || top1.id == 6.into());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_read_api() {
test_read_api_with_shards(1).await;
test_read_api_with_shards(N_SHARDS).await;
}
async fn test_read_api_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4, 5, 6, 7, 8]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![0.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
vec![0.0, 0.0, 0.0, 1.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.unwrap();
let result = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(2),
filter: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: None,
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
assert_eq!(result.next_page_offset, Some(2.into()));
assert_eq!(result.points.len(), 2);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_ordered_read_api() {
test_ordered_scroll_api_with_shards(1).await;
test_ordered_scroll_api_with_shards(N_SHARDS).await;
}
async fn test_ordered_scroll_api_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
const PRICE_FLOAT_KEY: &str = "price_float";
const PRICE_INT_KEY: &str = "price_int";
const MULTI_VALUE_KEY: &str = "multi_value";
let get_payload = |value: f64| -> Option<Payload> {
let mut payload_map = Map::new();
payload_map.insert(PRICE_FLOAT_KEY.to_string(), (value).into());
payload_map.insert(PRICE_INT_KEY.to_string(), (value as i64).into());
payload_map.insert(
MULTI_VALUE_KEY.to_string(),
vec![value, value + 20.0].into(),
);
Some(Payload(payload_map))
};
let payloads: Vec<Option<Payload>> = vec![
get_payload(11.0),
get_payload(10.0),
get_payload(9.0),
get_payload(8.0),
get_payload(7.0),
get_payload(6.0),
get_payload(5.0),
get_payload(5.0),
get_payload(5.0),
get_payload(5.0),
get_payload(4.0),
get_payload(3.0),
get_payload(2.0),
get_payload(1.0),
];
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![0.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
vec![0.0, 0.0, 0.0, 1.0],
vec![0.0, 1.0, 1.0, 1.0],
vec![0.0, 1.0, 1.0, 1.0],
vec![0.0, 1.0, 1.0, 1.0],
vec![0.0, 1.0, 1.0, 1.0],
vec![1.0, 1.0, 1.0, 1.0],
]),
payloads: Some(payloads),
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await
.unwrap();
collection
.create_payload_index_with_wait(
PRICE_FLOAT_KEY.parse().unwrap(),
PayloadFieldSchema::FieldType(PayloadSchemaType::Float),
true,
)
.await
.unwrap();
collection
.create_payload_index_with_wait(
PRICE_INT_KEY.parse().unwrap(),
PayloadFieldSchema::FieldType(PayloadSchemaType::Integer),
true,
)
.await
.unwrap();
collection
.create_payload_index_with_wait(
MULTI_VALUE_KEY.parse().unwrap(),
PayloadFieldSchema::FieldType(PayloadSchemaType::Float),
true,
)
.await
.unwrap();
///////// Test single-valued fields ///////////
for key in [PRICE_FLOAT_KEY, PRICE_INT_KEY] {
let result_asc = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(3),
filter: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: Some(OrderByInterface::Struct(OrderBy {
key: key.parse().unwrap(),
direction: Some(Direction::Asc),
start_from: None,
})),
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
assert_eq!(result_asc.points.len(), 3);
assert_eq!(result_asc.next_page_offset, None);
assert!(result_asc.points.iter().tuple_windows().all(|(a, b)| {
let a = a.payload.as_ref().unwrap();
let b = b.payload.as_ref().unwrap();
let a = a.0.get(key).unwrap().as_f64();
let b = b.0.get(key).unwrap().as_f64();
a <= b
}));
let result_desc = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(5),
filter: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: Some(OrderByInterface::Struct(OrderBy {
key: key.parse().unwrap(),
direction: Some(Direction::Desc),
start_from: None,
})),
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
assert_eq!(result_desc.points.len(), 5);
assert_eq!(result_desc.next_page_offset, None);
assert!(
result_desc.points.iter().tuple_windows().all(|(a, b)| {
let a = a.payload.as_ref().unwrap();
let b = b.payload.as_ref().unwrap();
let a = a.0.get(key).unwrap().as_f64();
let b = b.0.get(key).unwrap().as_f64();
a >= b
}),
"got: {:#?}",
result_desc.points
);
let asc_already_seen: HashSet<_> = result_asc.points.iter().map(|x| x.id).collect();
dbg!(&asc_already_seen);
let asc_second_page = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(5),
filter: Some(Filter::new_must_not(Condition::HasId(
HasIdCondition::from(asc_already_seen),
))),
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: Some(OrderByInterface::Struct(OrderBy {
key: key.parse().unwrap(),
direction: Some(Direction::Asc),
start_from: None,
})),
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
let asc_second_page_points = asc_second_page
.points
.iter()
.map(|x| x.id)
.collect::<HashSet<_>>();
let valid_asc_second_page_points = [10, 9, 8, 7, 6]
.into_iter()
.map(|x| x.into())
.collect::<HashSet<ExtendedPointId>>();
assert_eq!(asc_second_page.points.len(), 5);
assert!(asc_second_page_points.is_subset(&valid_asc_second_page_points));
let desc_already_seen: HashSet<_> = result_desc.points.iter().map(|x| x.id).collect();
dbg!(&desc_already_seen);
let desc_second_page = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(4),
filter: Some(Filter::new_must_not(Condition::HasId(
HasIdCondition::from(desc_already_seen),
))),
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: Some(OrderByInterface::Struct(OrderBy {
key: key.parse().unwrap(),
direction: Some(Direction::Desc),
start_from: None,
})),
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
let desc_second_page_points = desc_second_page
.points
.iter()
.map(|x| x.id)
.collect::<HashSet<_>>();
let valid_desc_second_page_points = [5, 6, 7, 8, 9]
.into_iter()
.map(|x| x.into())
.collect::<HashSet<ExtendedPointId>>();
assert_eq!(desc_second_page.points.len(), 4);
assert!(
desc_second_page_points.is_subset(&valid_desc_second_page_points),
"expected: {valid_desc_second_page_points:?}, got: {desc_second_page_points:?}"
);
}
///////// Test multi-valued field ///////////
let result_multi = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(100),
filter: None,
with_payload: Some(WithPayloadInterface::Bool(true)),
with_vector: false.into(),
order_by: Some(OrderByInterface::Key(MULTI_VALUE_KEY.parse().unwrap())),
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
assert!(result_multi
.points
.iter()
.fold(HashMap::<PointIdType, usize, _>::new(), |mut acc, point| {
acc.entry(point.id)
.and_modify(|x| {
*x += 1;
})
.or_insert(1);
acc
})
.values()
.all(|&x| x == 2));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_collection_delete_points_by_filter() {
test_collection_delete_points_by_filter_with_shards(1).await;
test_collection_delete_points_by_filter_with_shards(N_SHARDS).await;
}
async fn test_collection_delete_points_by_filter_with_shards(shard_number: u32) {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
let collection = simple_collection_fixture(collection_dir.path(), shard_number).await;
let batch = BatchPersisted {
ids: vec![0, 1, 2, 3, 4]
.into_iter()
.map(|x| x.into())
.collect_vec(),
vectors: BatchVectorStructPersisted::Single(vec![
vec![1.0, 0.0, 1.0, 1.0],
vec![1.0, 0.0, 1.0, 0.0],
vec![1.0, 1.0, 1.0, 1.0],
vec![1.0, 1.0, 0.0, 1.0],
vec![1.0, 0.0, 0.0, 0.0],
]),
payloads: None,
};
let insert_points = CollectionUpdateOperations::PointOperation(PointOperations::UpsertPoints(
PointInsertOperationsInternal::from(batch),
));
let insert_result = collection
.update_from_client_simple(insert_points, true, WriteOrdering::default())
.await;
match insert_result {
Ok(res) => {
assert_eq!(res.status, UpdateStatus::Completed)
}
Err(err) => panic!("operation failed: {err:?}"),
}
// delete points with id (0, 3)
let to_be_deleted: HashSet<PointIdType> = vec![0.into(), 3.into()].into_iter().collect();
let delete_filter =
segment::types::Filter::new_must(Condition::HasId(HasIdCondition::from(to_be_deleted)));
let delete_points = CollectionUpdateOperations::PointOperation(
PointOperations::DeletePointsByFilter(delete_filter),
);
let delete_result = collection
.update_from_client_simple(delete_points, true, WriteOrdering::default())
.await;
match delete_result {
Ok(res) => {
assert_eq!(res.status, UpdateStatus::Completed)
}
Err(err) => panic!("operation failed: {err:?}"),
}
let result = collection
.scroll_by(
ScrollRequestInternal {
offset: None,
limit: Some(10),
filter: None,
with_payload: Some(WithPayloadInterface::Bool(false)),
with_vector: false.into(),
order_by: None,
},
None,
&ShardSelectorInternal::All,
None,
)
.await
.unwrap();
// check if we only have 3 out of 5 points left and that the point id were really deleted
assert_eq!(result.points.len(), 3);
assert_eq!(result.points.first().unwrap().id, 1.into());
assert_eq!(result.points.get(1).unwrap().id, 2.into());
assert_eq!(result.points.get(2).unwrap().id, 4.into());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_collection_local_load_initializing_not_stuck() {
let collection_dir = Builder::new().prefix("collection").tempdir().unwrap();
// Create and unload collection
simple_collection_fixture(collection_dir.path(), 1).await;
// Modify replica state file on disk, set state to Initializing
// This is to simulate a situation where a collection was not fully created, we cannot create
// this situation through our collection interface
{
let replica_state_path = collection_dir.path().join("0/replica_state.json");
let replica_state_file = File::open(&replica_state_path).unwrap();
let mut replica_set_state: ReplicaSetState =
serde_json::from_reader(replica_state_file).unwrap();
for peer_id in replica_set_state.peers().into_keys() {
replica_set_state.set_peer_state(peer_id, ReplicaState::Initializing);
}
let replica_state_file = File::create(&replica_state_path).unwrap();
serde_json::to_writer(replica_state_file, &replica_set_state).unwrap();
}
// Reload collection
let collection_path = collection_dir.path();
let loaded_collection = load_local_collection(
"test".to_string(),
collection_path,
&collection_path.join("snapshots"),
)
.await;
// Local replica must be in Active state after loading (all replicas are local)
let loaded_state = loaded_collection.state().await;
for shard_info in loaded_state.shards.values() {
for replica_state in shard_info.replicas.values() {
assert_eq!(replica_state, &ReplicaState::Active);
}
}
}