File size: 2,554 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// The `metrics` module is compiled with the whole `unused` lint group silenced.
// NOTE(review): presumably not every item in `metrics` is referenced from this
// target — confirm, and either narrow this to the specific lint (e.g.
// `dead_code`) or record the concrete reason here.
#[allow(unused)]
mod metrics;

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use quantization::encoded_vectors::{DistanceType, VectorParameters};
    use quantization::encoded_vectors_u8::EncodedVectorsU8;
    use quantization::{EncodedVectorsPQ, EncodingError};

    /// Dimensionality of every input vector fed to the encoders.
    const VECTOR_DIM: usize = 8;

    /// Parameters shared by both stop-condition tests: one million
    /// `VECTOR_DIM`-dimensional vectors compared with the dot product.
    fn test_vector_parameters() -> VectorParameters {
        VectorParameters {
            dim: VECTOR_DIM,
            count: 1_000_000,
            distance_type: DistanceType::Dot,
            invert: false,
        }
    }

    /// u8 scalar encoding must abort with `EncodingError::Stopped` once the
    /// stop condition starts returning `true`.
    ///
    /// The stop flag is raised by the input iterator itself when it yields the
    /// middle vector, so the request always arrives while encoding is in
    /// progress. A timer-driven flag (sleep, then store) races the encoder and
    /// fails whenever encoding finishes before the timer, e.g. in release builds.
    #[test]
    fn stop_condition_u8() {
        let stopped = AtomicBool::new(false);
        let vector_parameters = test_vector_parameters();
        let halfway = vector_parameters.count / 2;
        let zero_vector = vec![0.0; VECTOR_DIM];

        let result = EncodedVectorsU8::encode(
            (0..vector_parameters.count).map(|i| {
                if i == halfway {
                    // Request a stop part-way through the input.
                    stopped.store(true, Ordering::Relaxed);
                }
                &zero_vector
            }),
            Vec::<u8>::new(),
            &vector_parameters,
            None,
            || stopped.load(Ordering::Relaxed),
        );

        assert!(
            matches!(result, Err(EncodingError::Stopped)),
            "u8 encoding did not report EncodingError::Stopped"
        );
    }

    /// Product-quantization encoding must abort with `EncodingError::Stopped`
    /// once the stop condition starts returning `true`.
    ///
    /// Uses the same iterator-driven stop request as `stop_condition_u8`, so the
    /// outcome does not depend on how long PQ encoding takes on this machine.
    #[test]
    fn stop_condition_pq() {
        let stopped = AtomicBool::new(false);
        let vector_parameters = test_vector_parameters();
        let halfway = vector_parameters.count / 2;
        let zero_vector = vec![0.0; VECTOR_DIM];

        let result = EncodedVectorsPQ::encode(
            (0..vector_parameters.count).map(|i| {
                if i == halfway {
                    // Request a stop part-way through the input.
                    stopped.store(true, Ordering::Relaxed);
                }
                &zero_vector
            }),
            Vec::<u8>::new(),
            &vector_parameters,
            2,
            1,
            || stopped.load(Ordering::Relaxed),
        );

        assert!(
            matches!(result, Err(EncodingError::Stopped)),
            "PQ encoding did not report EncodingError::Stopped"
        );
    }
}