Spaces:
Build error
Build error
//! Reciprocal Rank Fusion (RRF) is a method for combining rankings from multiple sources. | |
//! See https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf | |
use std::collections::hash_map::Entry; | |
use ahash::{HashMap, HashMapExt}; | |
use ordered_float::OrderedFloat; | |
use crate::types::{ExtendedPointId, ScoredPoint}; | |
/// Mitigates the impact of high rankings by outlier systems | |
const RFF_RANKING_K: f32 = 2.0; | |
/// Compute the RRF score for a given position. | |
fn position_score(position: usize) -> f32 { | |
1.0 / (position as f32 + RFF_RANKING_K) | |
} | |
/// Compute RRF scores for multiple results from different sources. | |
/// Each response can have a different length. | |
/// The input scores are irrelevant, only the order matters. | |
/// | |
/// The output is a single sorted list of ScoredPoint. | |
/// Does not break ties. | |
pub fn rrf_scoring(responses: impl IntoIterator<Item = Vec<ScoredPoint>>) -> Vec<ScoredPoint> { | |
// track scored points by id | |
let mut points_by_id: HashMap<ExtendedPointId, ScoredPoint> = HashMap::new(); | |
for response in responses { | |
for (pos, mut point) in response.into_iter().enumerate() { | |
let rrf_score = position_score(pos); | |
match points_by_id.entry(point.id) { | |
Entry::Occupied(mut entry) => { | |
// accumulate score | |
entry.get_mut().score += rrf_score; | |
} | |
Entry::Vacant(entry) => { | |
point.score = rrf_score; | |
// init score | |
entry.insert(point); | |
} | |
} | |
} | |
} | |
let mut scores: Vec<_> = points_by_id.into_values().collect(); | |
scores.sort_unstable_by(|a, b| { | |
// sort by score descending | |
OrderedFloat(b.score).cmp(&OrderedFloat(a.score)) | |
}); | |
scores | |
} | |
mod tests { | |
use super::*; | |
use crate::types::ScoredPoint; | |
fn make_scored_point(id: u64, score: f32) -> ScoredPoint { | |
ScoredPoint { | |
id: id.into(), | |
version: 0, | |
score, | |
payload: None, | |
vector: None, | |
shard_key: None, | |
order_value: None, | |
} | |
} | |
fn test_rrf_scoring_empty() { | |
let responses = vec![]; | |
let scored_points = rrf_scoring(responses); | |
assert_eq!(scored_points.len(), 0); | |
} | |
fn test_rrf_scoring_one() { | |
let responses = vec![vec![make_scored_point(1, 0.9)]]; | |
let scored_points = rrf_scoring(responses); | |
assert_eq!(scored_points.len(), 1); | |
assert_eq!(scored_points[0].id, 1.into()); | |
assert_eq!(scored_points[0].score, 0.5); // 1 / (0 + 2) | |
} | |
fn test_rrf_scoring() { | |
let responses = vec![ | |
vec![make_scored_point(2, 0.9), make_scored_point(1, 0.8)], | |
vec![ | |
make_scored_point(1, 0.7), | |
make_scored_point(2, 0.6), | |
make_scored_point(3, 0.5), | |
], | |
vec![ | |
make_scored_point(5, 0.9), | |
make_scored_point(3, 0.5), | |
make_scored_point(1, 0.4), | |
], | |
]; | |
// top 10 | |
let scored_points = rrf_scoring(responses); | |
assert_eq!(scored_points.len(), 4); | |
// assert that the list is sorted | |
assert!(scored_points.windows(2).all(|w| w[0].score >= w[1].score)); | |
assert_eq!(scored_points.len(), 4); | |
assert_eq!(scored_points[0].id, 1.into()); | |
assert_eq!(scored_points[0].score, 1.0833334); | |
assert_eq!(scored_points[1].id, 2.into()); | |
assert_eq!(scored_points[1].score, 0.8333334); | |
assert_eq!(scored_points[2].id, 3.into()); | |
assert_eq!(scored_points[2].score, 0.5833334); | |
assert_eq!(scored_points[3].id, 5.into()); | |
assert_eq!(scored_points[3].score, 0.5); | |
} | |
} | |