Spaces:
Build error
Build error
File size: 3,390 Bytes
84d2a97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
use collection::collection::distance_matrix::CollectionSearchMatrixRequest;
use collection::operations::point_ops::{
BatchPersisted, BatchVectorStructPersisted, WriteOrdering,
};
use collection::operations::shard_selector_internal::ShardSelectorInternal;
use common::counter::hardware_accumulator::HwMeasurementAcc;
use itertools::Itertools;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use tempfile::Builder;
use crate::common::simple_collection_fixture;
/// Fixed seed passed to `SmallRng::seed_from_u64` when generating the random test vectors.
const SEED: u64 = 42;
/// Runs a matrix search against a collection into which no points were inserted and
/// expects both the sampled ids and the nearest-neighbour rows to come back empty.
#[tokio::test(flavor = "multi_thread")]
async fn distance_matrix_empty() {
    // Nothing is upserted below, so the fixture collection stays empty.
    let storage_dir = Builder::new().prefix("storage").tempdir().unwrap();
    let empty_collection = simple_collection_fixture(storage_dir.path(), 1).await;

    let request = CollectionSearchMatrixRequest {
        sample_size: 100,
        limit_per_sample: 10,
        filter: None,
        using: "".to_string(), // default vector name
    };

    let hw_counter = HwMeasurementAcc::new();
    let search_outcome = empty_collection
        .search_points_matrix(request, ShardSelectorInternal::All, None, None, &hw_counter)
        .await;
    let matrix = search_outcome.unwrap();
    // The accumulated measurements are not inspected by this test.
    hw_counter.discard();

    // Neither samples nor neighbour rows are expected.
    assert!(matrix.sample_ids.is_empty());
    assert!(matrix.nearests.is_empty());
}
/// Upserts 2000 random 4-dimensional vectors under the default vector name, runs a matrix
/// search and checks that:
/// - exactly `sample_size` distinct sample ids are returned,
/// - there is one neighbour row per sample, each holding `limit_per_sample` entries,
/// - every row is ordered by non-increasing score.
#[tokio::test(flavor = "multi_thread")]
async fn distance_matrix_anonymous_vector() {
    let storage_dir = Builder::new().prefix("storage").tempdir().unwrap();
    let test_collection = simple_collection_fixture(storage_dir.path(), 1).await;

    // Build the batch: sequential ids paired with seeded random vectors.
    let point_count = 2000;
    let mut rng = SmallRng::seed_from_u64(SEED);
    let batch = BatchPersisted {
        ids: (0..point_count).map_into().collect(),
        vectors: BatchVectorStructPersisted::Single(
            (0..point_count)
                .map(|_| rng.gen::<[f32; 4]>().to_vec())
                .collect_vec(),
        ),
        payloads: None,
    };

    // Wrap the batch step by step into a collection-level upsert operation.
    let insert_op = collection::operations::point_ops::PointInsertOperationsInternal::from(batch);
    let point_op = collection::operations::point_ops::PointOperations::UpsertPoints(insert_op);
    let upsert_points =
        collection::operations::CollectionUpdateOperations::PointOperation(point_op);

    test_collection
        .update_from_client_simple(upsert_points, true, WriteOrdering::default())
        .await
        .unwrap();

    let sample_size = 100;
    let limit_per_sample = 10;
    let request = CollectionSearchMatrixRequest {
        sample_size,
        limit_per_sample,
        filter: None,
        using: "".to_string(), // default vector name
    };

    let hw_counter = HwMeasurementAcc::new();
    let search_outcome = test_collection
        .search_points_matrix(request, ShardSelectorInternal::All, None, None, &hw_counter)
        .await;
    let matrix = search_outcome.unwrap();
    // The accumulated measurements are not inspected by this test.
    hw_counter.discard();

    assert_eq!(matrix.sample_ids.len(), sample_size);

    // no duplicate sample ids: collapsing into a set must not lose any entry
    let distinct_ids: std::collections::HashSet<_> = matrix.sample_ids.iter().collect();
    assert_eq!(distinct_ids.len(), sample_size);

    assert_eq!(matrix.nearests.len(), sample_size);
    for row in matrix.nearests {
        assert_eq!(row.len(), limit_per_sample);
        // each row must be sorted by score, best first
        for (prev, next) in row.iter().tuple_windows() {
            assert!(prev.score >= next.score);
        }
    }
}
|