File size: 3,390 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use collection::collection::distance_matrix::CollectionSearchMatrixRequest;
use collection::operations::point_ops::{
    BatchPersisted, BatchVectorStructPersisted, WriteOrdering,
};
use collection::operations::shard_selector_internal::ShardSelectorInternal;
use common::counter::hardware_accumulator::HwMeasurementAcc;
use itertools::Itertools;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use tempfile::Builder;

use crate::common::simple_collection_fixture;

/// Fixed seed for the random number generator, so the vectors generated by the
/// tests in this file are the same on every run.
const SEED: u64 = 42;

/// Requesting a distance matrix from a collection that holds no points must
/// succeed and return neither sampled ids nor nearest-neighbour rows.
#[tokio::test(flavor = "multi_thread")]
async fn distance_matrix_empty() {
    let storage_dir = Builder::new().prefix("storage").tempdir().unwrap();

    // The fixture is used as-is: nothing is upserted, so the collection is empty.
    let collection = simple_collection_fixture(storage_dir.path(), 1).await;

    let request = CollectionSearchMatrixRequest {
        sample_size: 100,
        limit_per_sample: 10,
        filter: None,
        using: String::new(), // default vector name
    };

    let hw_counter = HwMeasurementAcc::new();
    let result = collection
        .search_points_matrix(request, ShardSelectorInternal::All, None, None, &hw_counter)
        .await
        .unwrap();
    // The measurements gathered during the search are not inspected by this test.
    hw_counter.discard();

    // With nothing to sample from, both parts of the result stay empty.
    assert!(result.sample_ids.is_empty());
    assert!(result.nearests.is_empty());
}

/// Fills a collection with 2000 seeded random 4-dimensional vectors stored under
/// the default (unnamed) vector, then checks the shape of the distance matrix:
/// the requested number of distinct samples, and for every sample a full row of
/// neighbours ordered by non-increasing score.
#[tokio::test(flavor = "multi_thread")]
async fn distance_matrix_anonymous_vector() {
    let storage_dir = Builder::new().prefix("storage").tempdir().unwrap();
    let collection = simple_collection_fixture(storage_dir.path(), 1).await;

    let point_count = 2000;

    // Seeded generator keeps the inserted vectors identical between runs.
    let mut rng = SmallRng::seed_from_u64(SEED);
    let mut vectors = Vec::new();
    for _ in 0..point_count {
        vectors.push(rng.gen::<[f32; 4]>().to_vec());
    }

    // One point per vector, with consecutive ids starting at zero.
    let batch = BatchPersisted {
        ids: (0..point_count).map_into().collect(),
        vectors: BatchVectorStructPersisted::Single(vectors),
        payloads: None,
    };

    // Wrap the batch layer by layer into the update operation the collection accepts.
    let insert_op = collection::operations::point_ops::PointInsertOperationsInternal::from(batch);
    let point_op = collection::operations::point_ops::PointOperations::UpsertPoints(insert_op);
    let update_op = collection::operations::CollectionUpdateOperations::PointOperation(point_op);

    collection
        .update_from_client_simple(update_op, true, WriteOrdering::default())
        .await
        .unwrap();

    let sample_size = 100;
    let limit_per_sample = 10;
    let request = CollectionSearchMatrixRequest {
        sample_size,
        limit_per_sample,
        filter: None,
        using: String::new(), // default vector name
    };

    let hw_counter = HwMeasurementAcc::new();
    let matrix = collection
        .search_points_matrix(request, ShardSelectorInternal::All, None, None, &hw_counter)
        .await
        .unwrap();
    // The measurements gathered during the search are not inspected by this test.
    hw_counter.discard();

    // Exactly `sample_size` samples were drawn ...
    assert_eq!(matrix.sample_ids.len(), sample_size);
    // ... and collecting them into a set loses none, i.e. there are no duplicates.
    let distinct_ids: std::collections::HashSet<_> = matrix.sample_ids.iter().collect();
    assert_eq!(distinct_ids.len(), sample_size);

    // One row of neighbours per sample.
    assert_eq!(matrix.nearests.len(), sample_size);
    for row in &matrix.nearests {
        assert_eq!(row.len(), limit_per_sample);
        // Every adjacent pair in the row must be ordered by non-increasing score.
        for (prev, next) in row.iter().tuple_windows() {
            assert!(prev.score >= next.score);
        }
    }
}