File size: 4,477 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::operations::types::CollectionResult;

/// Processes a sequence of requests in batches of consecutive requests that share the same key.
///
/// Requests are split into sequential subgroups: a new subgroup starts whenever the key
/// returned by `get_key` differs from the key of the previous request. The input is neither
/// sorted nor re-ordered, so a key whose requests are not adjacent produces several subgroups.
///
/// There are two customizable aggregation functions:
///
/// * `accumulate_local` - called for each request to add it to the current subgroup
///   (the local accumulator, which starts from `Acc1::default()` for every subgroup)
/// * `accumulate_global` - called once for each finished subgroup with the subgroup key,
///   the local accumulator and a mutable reference to the global accumulator
///
/// The key is cloned once per subgroup; the clone is handed over to `accumulate_global`.
///
/// # Returns
///
/// The global accumulator after the last subgroup has been passed to `accumulate_global`.
/// For an empty input neither callback is invoked and `Acc2::default()` is returned.
///
/// # Errors
///
/// The first error returned by `accumulate_local` or `accumulate_global` is propagated
/// immediately; the remaining requests are not processed.
///
/// # Examples
///
/// Example usage (simplified):
///
/// ```python
/// requests = [
///     Recommend(positive=[1], shard_key="cats"),
///     Recommend(positive=[2], shard_key="cats"),
///     Recommend(positive=[3], shard_key="dogs"),
///     Recommend(positive=[4], shard_key="dogs"),
/// ]
/// ```
///
/// We want to:
///     1. Group requests by shard_key into Acc1 (vector of requests)
///     2. Execute each group of requests and push the result into Acc2 (vector of results)
///
/// How to do that:
///
/// ```rust,ignore
/// batch_requests::<
///     Recommend,             // Type of request
///     String,                // Type of shard_key
///     Vec<Recommend>,        // Type of local accumulator
///     Vec<Vec<ScoredPoint>>, // Type of global accumulator,
/// >(
///     requests,
///     |request| &request.shard_key,
///     |request, local_accumulator| { // Accumulate requests
///         local_accumulator.push(request);
///         // Note: we can have more complex logic here
///         // E.g. extracting IDs from requests and de-duplicating them
///         Ok(())
///     },
///     |shard_key, local_accumulator, global_accumulator| { // Execute requests and accumulate results
///        let result = execute_recommendations(local_accumulator, shard_key);
///        global_accumulator.push(result);
///        Ok(())
///     }
/// )
/// ```
pub fn batch_requests<Req, Key: PartialEq + Clone, Acc1: Default, Acc2: Default>(
    requests: impl IntoIterator<Item = Req>,
    get_key: impl Fn(&Req) -> &Key,
    mut accumulate_local: impl FnMut(Req, &mut Acc1) -> CollectionResult<()>,
    mut accumulate_global: impl FnMut(Key, Acc1, &mut Acc2) -> CollectionResult<()>,
) -> CollectionResult<Acc2> {
    let mut local_accumulator = Acc1::default();
    let mut global_accumulator = Acc2::default();
    // Key of the subgroup currently being accumulated; `None` until the first request is seen.
    let mut prev_key: Option<Key> = None;
    for request in requests {
        let request_key = get_key(&request);
        if prev_key.as_ref() != Some(request_key) {
            // A new subgroup starts here. Remember its key and take the previous key out by
            // value, so the finished subgroup can be flushed without cloning the key again.
            if let Some(finished_key) = prev_key.replace(request_key.clone()) {
                accumulate_global(finished_key, local_accumulator, &mut global_accumulator)?;
                local_accumulator = Acc1::default();
            }
        }
        accumulate_local(request, &mut local_accumulator)?;
    }
    // Flush the trailing subgroup; `prev_key` is still `None` only when there were no requests.
    if let Some(last_key) = prev_key {
        accumulate_global(last_key, local_accumulator, &mut global_accumulator)?;
    }
    Ok(global_accumulator)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A test request: the first element acts as the key, the second is an arbitrary payload.
    type Request = (char, usize);
    /// One flushed subgroup: its key together with the requests collected under it.
    type Group = (char, Vec<Request>);

    /// Runs `batch_requests` over `requests`, storing every request verbatim in the local
    /// accumulator and every `(key, subgroup)` pair in the global accumulator.
    fn collect_groups(requests: &[Request]) -> Vec<Group> {
        let grouped = batch_requests::<Request, char, Vec<Request>, Vec<Group>>(
            requests.iter().copied(),
            |request| &request.0,
            |request, group| {
                group.push(request);
                Ok(())
            },
            |key, group, groups| {
                groups.push((key, group));
                Ok(())
            },
        );
        grouped.unwrap()
    }

    #[test]
    fn test_batch_requests() {
        // Every request has a different key: each one forms its own subgroup.
        let singletons = collect_groups(&[('a', 1), ('b', 2), ('c', 3)]);
        let expected_singletons: Vec<Group> = vec![
            ('a', vec![('a', 1)]),
            ('b', vec![('b', 2)]),
            ('c', vec![('c', 3)]),
        ];
        assert_eq!(singletons, expected_singletons);

        // Adjacent requests with equal keys end up in the same subgroup.
        let pairs = collect_groups(&[('a', 1), ('a', 2), ('b', 3), ('b', 4), ('c', 5), ('c', 6)]);
        let expected_pairs: Vec<Group> = vec![
            ('a', vec![('a', 1), ('a', 2)]),
            ('b', vec![('b', 3), ('b', 4)]),
            ('c', vec![('c', 5), ('c', 6)]),
        ];
        assert_eq!(pairs, expected_pairs);

        // Equal keys that are not adjacent are NOT merged: 'a' yields two subgroups.
        let interleaved = collect_groups(&[('a', 1), ('b', 2), ('a', 3)]);
        let expected_interleaved: Vec<Group> = vec![
            ('a', vec![('a', 1)]),
            ('b', vec![('b', 2)]),
            ('a', vec![('a', 3)]),
        ];
        assert_eq!(interleaved, expected_interleaved);

        // No requests means no subgroups.
        let nothing = collect_groups(&[]);
        assert!(nothing.is_empty());
    }
}