File size: 3,868 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
//! Reciprocal Rank Fusion (RRF) is a method for combining rankings from multiple sources.
//! See https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

use std::collections::hash_map::Entry;

use ahash::{HashMap, HashMapExt};
use ordered_float::OrderedFloat;

use crate::types::{ExtendedPointId, ScoredPoint};

/// Smoothing constant `k` of the RRF formula.
///
/// Mitigates the impact of high rankings by outlier systems: the larger `k`, the smaller the
/// relative difference between the contributions of neighbouring positions.
const RRF_RANKING_K: f32 = 2.0;

/// Compute the RRF contribution of a result at the given 0-based `position` in one list.
///
/// Returns `1 / (position + RRF_RANKING_K)`. The first entry of a list (position 0) therefore
/// contributes `1 / RRF_RANKING_K`, and later entries contribute progressively less.
fn position_score(position: usize) -> f32 {
    1.0 / (position as f32 + RRF_RANKING_K)
}

/// Fuse several ranked result lists into a single list using Reciprocal Rank Fusion.
///
/// Each occurrence of a point earns `position_score(position)`, where `position` is its
/// 0-based index inside the list it came from. Contributions for the same id are summed,
/// including repeated occurrences inside one list. The lists may have different lengths.
/// Only the list order is used: the incoming `score` values are overwritten by the fused score.
///
/// For a point seen more than once, every field other than `score` is taken from its first
/// occurrence.
///
/// The returned list is ordered by fused score, highest first. Ties are not broken, so points
/// with equal fused scores come out in no particular order.
pub fn rrf_scoring(responses: impl IntoIterator<Item = Vec<ScoredPoint>>) -> Vec<ScoredPoint> {
    let mut fused: HashMap<ExtendedPointId, ScoredPoint> = HashMap::new();

    // Walk every list in order, pairing each point with its position inside its own list.
    let ranked_candidates = responses
        .into_iter()
        .flat_map(|response| response.into_iter().enumerate());

    for (rank, mut candidate) in ranked_candidates {
        let contribution = position_score(rank);
        match fused.entry(candidate.id) {
            Entry::Vacant(slot) => {
                // First sighting: replace the source-specific score with the RRF contribution.
                candidate.score = contribution;
                slot.insert(candidate);
            }
            Entry::Occupied(mut slot) => {
                slot.get_mut().score += contribution;
            }
        }
    }

    let mut fused_points: Vec<ScoredPoint> = fused.into_values().collect();
    // Ascending comparison, reversed, gives highest fused score first.
    fused_points.sort_unstable_by(|lhs, rhs| {
        OrderedFloat(lhs.score)
            .cmp(&OrderedFloat(rhs.score))
            .reverse()
    });
    fused_points
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ScoredPoint;

    /// Build a minimal `ScoredPoint` with the given id and (source-specific) score.
    fn make_scored_point(id: u64, score: f32) -> ScoredPoint {
        ScoredPoint {
            id: id.into(),
            version: 0,
            score,
            payload: None,
            vector: None,
            shard_key: None,
            order_value: None,
        }
    }

    #[test]
    fn test_rrf_scoring_empty() {
        let responses = vec![];
        let scored_points = rrf_scoring(responses);
        assert_eq!(scored_points.len(), 0);
    }

    #[test]
    fn test_rrf_scoring_one() {
        let responses = vec![vec![make_scored_point(1, 0.9)]];
        let scored_points = rrf_scoring(responses);
        assert_eq!(scored_points.len(), 1);
        assert_eq!(scored_points[0].id, 1.into());
        assert_eq!(scored_points[0].score, 0.5); // 1 / (0 + 2)
    }

    #[test]
    fn test_rrf_scoring_ignores_input_scores() {
        // The input scores are inverted relative to list order; only the position within
        // the list may influence the fused result.
        let responses = vec![vec![
            make_scored_point(7, -1.0),
            make_scored_point(8, 100.0),
        ]];
        let scored_points = rrf_scoring(responses);
        assert_eq!(scored_points.len(), 2);
        assert_eq!(scored_points[0].id, 7.into());
        assert_eq!(scored_points[0].score, position_score(0));
        assert_eq!(scored_points[1].id, 8.into());
        assert_eq!(scored_points[1].score, position_score(1));
    }

    #[test]
    fn test_rrf_scoring() {
        let responses = vec![
            vec![make_scored_point(2, 0.9), make_scored_point(1, 0.8)],
            vec![
                make_scored_point(1, 0.7),
                make_scored_point(2, 0.6),
                make_scored_point(3, 0.5),
            ],
            vec![
                make_scored_point(5, 0.9),
                make_scored_point(3, 0.5),
                make_scored_point(1, 0.4),
            ],
        ];

        let scored_points = rrf_scoring(responses);
        // ids 1, 2, 3 and 5 are the distinct ids across all responses
        assert_eq!(scored_points.len(), 4);
        // assert that the list is sorted by descending score
        assert!(scored_points.windows(2).all(|w| w[0].score >= w[1].score));

        // id 1 at positions 1, 0, 2: 1/3 + 1/2 + 1/4
        assert_eq!(scored_points[0].id, 1.into());
        assert_eq!(scored_points[0].score, 1.0833334);

        // id 2 at positions 0, 1: 1/2 + 1/3
        assert_eq!(scored_points[1].id, 2.into());
        assert_eq!(scored_points[1].score, 0.8333334);

        // id 3 at positions 2, 1: 1/4 + 1/3
        assert_eq!(scored_points[2].id, 3.into());
        assert_eq!(scored_points[2].score, 0.5833334);

        // id 5 at position 0 only: 1/2
        assert_eq!(scored_points[3].id, 5.into());
        assert_eq!(scored_points[3].score, 0.5);
    }
}