use std::collections::HashMap;
use std::iter;

use common::types::ScoreType;
use itertools::{Itertools, MinMaxResult};
use ordered_float::OrderedFloat;

use crate::types::{Order, PointIdType, ScoredPoint};

pub struct ScoreFusion {
    /// Defines how to combine the scores of the same point in different lists
    pub method: Aggregation,
    /// Defines how to normalize the scores in each list
    pub norm: Normalization,
    /// Multipliers for each list of scores
    pub weights: Vec<f32>,
    /// Final ordering of the results
    pub order: Order,
}

impl ScoreFusion {
    /// Params for distribution-based score fusion (DBSF)
    pub fn dbsf() -> Self {
        Self {
            method: Aggregation::Sum,
            norm: Normalization::Distr,
            weights: vec![],
            order: Order::LargeBetter,
        }
    }
}

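// Illustrative sketch (not part of the original API): since all fields are
// public, parameters for a plain min-max weighted fusion can also be built
// directly. The weights below are made-up values for the example.
//
//     let params = ScoreFusion {
//         method: Aggregation::Sum,
//         norm: Normalization::MinMax,
//         weights: vec![2.0, 1.0], // first result list counts double
//         order: Order::LargeBetter,
//     };
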
/// Defines how to combine the scores of the same point in different lists
pub enum Aggregation {
    /// Sums the scores
    Sum,
}

/// Defines how to normalize the scores in each list
pub enum Normalization {
    /// Uses the minimum and maximum scores as extremes
    MinMax,
    /// Uses mean ± 3 standard deviations as extremes
    Distr,
}

pub fn score_fusion(
    all_results: impl IntoIterator<Item = Vec<ScoredPoint>>,
    params: ScoreFusion,
) -> Vec<ScoredPoint> {
    let ScoreFusion {
        method,
        norm,
        weights,
        order,
    } = params;

    // Lists without an explicit weight default to 1.0
    let weights = weights.into_iter().chain(iter::repeat(1.0));

    all_results
        .into_iter()
        // normalize each list independently
        .map(|points| match norm {
            Normalization::MinMax => min_max_norm(points),
            Normalization::Distr => distr_norm(points),
        })
        // weight each list of points
        .zip(weights)
        .flat_map(|(points, weight)| {
            points.into_iter().map(move |p| ScoredPoint {
                score: p.score * weight,
                ..p
            })
        })
        // combine scores of the same point id to deduplicate
        .fold(
            HashMap::<PointIdType, ScoredPoint>::new(),
            |mut acc, point| {
                acc.entry(point.id)
                    .and_modify(|entry| match method {
                        Aggregation::Sum => entry.score += point.score,
                    })
                    .or_insert(point);
                acc
            },
        )
        // sort and return
        .into_values()
        .sorted_by(|a, b| match order {
            Order::SmallBetter => a.cmp(b),
            Order::LargeBetter => b.cmp(a),
        })
        .collect()
}

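// Hypothetical call-site sketch: `dense_hits` and `sparse_hits` stand in for
// two independently produced result lists. Each list is normalized on its
// own, and points sharing an id get their weighted scores summed.
//
//     let fused = score_fusion([dense_hits, sparse_hits], ScoreFusion::dbsf());
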
/// Normalizes the scores of the given points between 0.0 and 1.0, using the
/// given minimum and maximum scores as extremes.
fn norm(mut points: Vec<ScoredPoint>, min: ScoreType, max: ScoreType) -> Vec<ScoredPoint> {
    // Protect against division by zero: with no spread, map every score to the midpoint
    if min == max {
        points.iter_mut().for_each(|p| p.score = 0.5);
        return points;
    }

    points.iter_mut().for_each(|p| {
        p.score = (p.score - min) / (max - min);
    });

    points
}

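// Worked example for `norm`: with min = 2.0 and max = 6.0, a raw score of 5.0
// maps to (5.0 - 2.0) / (6.0 - 2.0) = 0.75. Scores outside [min, max] are not
// clamped and will map outside [0.0, 1.0].
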
/// Normalizes the scores of the given points between 0.0 and 1.0, using the
/// observed minimum and maximum scores as extremes. Lists with fewer than two
/// points are returned unchanged.
pub fn min_max_norm(points: Vec<ScoredPoint>) -> Vec<ScoredPoint> {
    let (min, max) = match points.iter().map(|p| OrderedFloat(p.score)).minmax() {
        MinMaxResult::NoElements | MinMaxResult::OneElement(_) => return points,
        MinMaxResult::MinMax(min, max) => (min.0, max.0),
    };
    norm(points, min, max)
}

/// Welford's method for stable one-pass mean and variance calculation.
/// <https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/>
///
/// # Panics
///
/// Panics in debug builds if the given slice of points has fewer than 2 elements.
fn welfords_mean_variance(points: &[ScoredPoint]) -> (f32, f32) {
    debug_assert!(
        points.len() > 1,
        "Not enough points to calculate mean and variance"
    );

    let mut mean = 0.0;
    let mut aggregate = 0.0;
    for (p, k) in points.iter().zip(1usize..) {
        let old_delta = p.score - mean;
        mean += old_delta / (k as f32);

        let delta = p.score - mean;
        aggregate += old_delta * delta;
    }

    let sample_variance = aggregate / (points.len() as f32 - 1.0);

    (mean, sample_variance)
}

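// For reference, the update rules implemented above are:
//   mean_k = mean_(k-1) + (x_k - mean_(k-1)) / k
//   M2_k   = M2_(k-1) + (x_k - mean_(k-1)) * (x_k - mean_k)
// with sample variance M2_n / (n - 1), where `aggregate` plays the role of M2.
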
/// Estimates the mean and variance of the given points and normalizes their
/// scores between 0.0 and 1.0, using mean ± 3 standard deviations as extremes.
pub fn distr_norm(mut points: Vec<ScoredPoint>) -> Vec<ScoredPoint> {
    // Not enough points to estimate a distribution
    if points.len() < 2 {
        if points.len() == 1 {
            points[0].score = 0.5;
        }
        return points;
    }

    let (mean, variance) = welfords_mean_variance(&points);
    let std_dev = variance.sqrt();

    let min = mean - 3.0 * std_dev;
    let max = mean + 3.0 * std_dev;

    norm(points, min, max)
}

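// Example: for scores with mean 0.0 and standard deviation 1.0, a score of
// -3.0 maps to 0.0, 0.0 maps to 0.5, and +3.0 maps to 1.0; anything beyond
// three standard deviations from the mean lands outside [0.0, 1.0].
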
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    fn point(id: usize, score: ScoreType) -> ScoredPoint {
        ScoredPoint {
            id: PointIdType::NumId(id as u64),
            version: 0,
            score,
            payload: None,
            vector: None,
            shard_key: None,
            order_value: None,
        }
    }

    fn assert_close(a: f32, b: f32) {
        assert!((a - b).abs() < 1e-2, "{a} is not close to {b}");
    }

    proptest! {
        #[test]
        fn welford_calc_vs_naive(scores in prop::collection::vec(-100.0..100.0f32, 2..1000)) {
            let naive_mean = scores.iter().sum::<f32>() / scores.len() as f32;
            let naive_variance = scores.iter().map(|p| (p - naive_mean).powi(2)).sum::<f32>()
                / (scores.len() - 1) as f32;

            let points = scores
                .into_iter()
                .enumerate()
                .map(|(i, s)| point(i, s))
                .collect_vec();

            let (mean, variance) = welfords_mean_variance(&points);

            assert_close(mean, naive_mean);
            assert_close(variance, naive_variance);
        }
    }
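
    // Hedged example test (not from the original suite) covering the full
    // fusion pipeline. Note that `min_max_norm` leaves single-element lists
    // untouched, so the fused score is the plain sum 2.0 + 4.0.
    #[test]
    fn sums_scores_of_duplicate_ids() {
        let params = ScoreFusion {
            method: Aggregation::Sum,
            norm: Normalization::MinMax,
            weights: vec![],
            order: Order::LargeBetter,
        };
        let fused = score_fusion([vec![point(1, 2.0)], vec![point(1, 4.0)]], params);
        assert_eq!(fused.len(), 1);
        assert_close(fused[0].score, 6.0);
    }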
}