File size: 2,861 Bytes
84d2a97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

/// Handle to a task started by [`spawn_async_cancellable`].
///
/// Lets the owner request cooperative cancellation, check whether the task has
/// finished, read a clone of the task's result, or take the underlying [`JoinHandle`].
///
/// The `Clone` requirement on `T` lives on the `impl` block that needs it rather than on
/// the type itself.
pub struct CancellableAsyncTaskHandle<T> {
    /// Tokio handle of the spawned task; awaiting it yields the task's output
    /// (or a `JoinError`).
    pub join_handle: JoinHandle<T>,
    /// Slot into which the task writes a clone of its output once it completes.
    result_holder: Arc<Mutex<Option<T>>>,
    /// Token shared with the task; triggering it requests (cooperative) cancellation.
    cancelled: CancellationToken,
    /// Flag the task sets to `true` when it finishes.
    finished: Arc<AtomicBool>,
}

impl<T: Clone> CancellableAsyncTaskHandle<T> {
    /// Returns `true` once the spawned task has raised its shared `finished` flag.
    ///
    /// The load uses `Acquire` so it pairs with the `Release` store performed by the task.
    /// After observing `true`, every write the task made before raising the flag is visible
    /// to this thread. A `Relaxed` load would make that `Release` store pointless.
    pub fn is_finished(&self) -> bool {
        self.finished.load(Ordering::Acquire)
    }

    /// Requests cooperative cancellation by triggering the token handed to the task.
    ///
    /// This does not wait for the task to stop. The task only reacts if it checks the token.
    pub fn ask_to_cancel(&self) {
        self.cancelled.cancel();
    }

    /// Requests cancellation (see [`Self::ask_to_cancel`]) and returns the task's
    /// [`JoinHandle`].
    ///
    /// The caller can await the handle to wait for the task to stop and obtain its output.
    pub fn cancel(self) -> JoinHandle<T> {
        self.ask_to_cancel();
        self.join_handle
    }

    /// Returns a clone of the result stored by the task, or `None` if nothing has been
    /// stored (yet).
    pub fn get_result(&self) -> Option<T> {
        self.result_holder.lock().clone()
    }
}

/// Spawns the future produced by `f` on the Tokio runtime and returns a handle that supports
/// cooperative cancellation.
///
/// `f` receives a [`CancellationToken`]. The future it returns is expected to check that
/// token and return early once cancellation is requested through the handle. Nothing here
/// forcibly interrupts the future.
///
/// When the future completes, two things happen before the `finished` flag is raised:
/// - a clone of its output is stored in the handle (see
///   `CancellableAsyncTaskHandle::get_result`);
/// - the original output is returned through the [`JoinHandle`].
///
/// The `finished` flag is raised by a drop guard owned by the task's future. It is therefore
/// also set if the future panics, or if the task is aborted or dropped before completing. In
/// those cases no result is stored, so `get_result` keeps returning `None`.
///
/// # Panics
///
/// Like [`tokio::task::spawn`], which it calls, this panics when used outside a Tokio runtime.
pub fn spawn_async_cancellable<F, T>(f: F) -> CancellableAsyncTaskHandle<T::Output>
where
    F: FnOnce(CancellationToken) -> T,
    F: Send + 'static,
    T: Future + Send + 'static,
    T::Output: Clone + Send + 'static,
{
    /// Raises the `finished` flag when dropped, i.e. whenever the task's future is torn down.
    /// That covers normal completion, a panic inside the future, and abortion or dropping of
    /// the task.
    struct FinishedGuard(Arc<AtomicBool>);

    impl Drop for FinishedGuard {
        fn drop(&mut self) {
            // `Release` publishes every write the task made before finishing (in particular
            // the stored result) to readers that observe `true` with an `Acquire` load.
            self.0.store(true, Ordering::Release);
        }
    }

    let cancelled = CancellationToken::new();
    let finished = Arc::new(AtomicBool::new(false));
    let result_holder = Arc::new(Mutex::new(None));

    // Build the guard outside the `async` block so the future captures (owns) it from the
    // start. Even if the task is aborted before its first poll, dropping the future drops the
    // guard and raises the flag.
    let finished_guard = FinishedGuard(finished.clone());
    let task = {
        let cancel = cancelled.clone();
        let result_holder = result_holder.clone();
        async move {
            // Declared first so it is dropped last, after the result has been stored.
            let _finished_guard = finished_guard;
            let res = f(cancel).await;
            // The lock is released at the end of this statement, keeping the critical section
            // as small as possible.
            *result_holder.lock() = Some(res.clone());
            res
        }
    };

    CancellableAsyncTaskHandle {
        join_handle: tokio::task::spawn(task),
        result_holder,
        cancelled,
        finished,
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::time::{sleep, timeout};

    use super::*;

    const STEP_MILLIS: u64 = 5;
    /// Number of steps `long_task` runs when it is never cancelled. This is kept far above the
    /// point at which the test cancels, so "stopped early" is a meaningful assertion.
    const TOTAL_STEPS: i32 = 100;

    /// Iterates over `0..TOTAL_STEPS`, sleeping `STEP_MILLIS` per step.
    ///
    /// It stops as soon as `cancel` has been triggered and returns the index of the last step
    /// it reached.
    async fn long_task(cancel: CancellationToken) -> i32 {
        let mut n = 0;
        for i in 0..TOTAL_STEPS {
            n = i;
            if cancel.is_cancelled() {
                break;
            }
            sleep(Duration::from_millis(STEP_MILLIS)).await;
        }
        n
    }

    #[tokio::test]
    async fn test_task_stop() {
        let handle = spawn_async_cancellable(long_task);

        sleep(Duration::from_millis(STEP_MILLIS * 5)).await;
        assert!(!handle.is_finished());
        handle.ask_to_cancel();

        // Poll for completion instead of sleeping a fixed, platform-dependent amount.
        // Timer resolution varies (notably on Windows), and the task only notices cancellation
        // between steps. The generous timeout keeps a genuine hang from stalling the suite.
        timeout(Duration::from_secs(5), async {
            while !handle.is_finished() {
                sleep(Duration::from_millis(STEP_MILLIS)).await;
            }
        })
        .await
        .expect("task did not finish after being asked to cancel");

        let stored = handle.get_result();
        let res = handle.cancel().await.unwrap();
        assert!(
            res < TOTAL_STEPS - 1,
            "task ran to completion despite being cancelled (last step: {res})"
        );
        assert_eq!(stored, Some(res));
    }
}