Spaces:
Build error
Build error
File size: 6,716 Bytes
84d2a97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use tokio::task::JoinHandle;
/// Boxed payload of a panic that occurred inside a stoppable task, as obtained from Tokio's
/// `JoinError::into_panic`.
type PanicPayload = Box<dyn Any + Send>;
/// Handle to a blocking task spawned with [`spawn_stoppable`].
///
/// The task can be asked to stop cooperatively through a stop flag shared with it. Its result,
/// if any, is obtained by awaiting [`join_handle`](Self::join_handle).
pub struct StoppableTaskHandle<T> {
    /// Join handle of the underlying blocking task. Resolves to `None` if the task observed the
    /// stop flag before calling its function.
    pub join_handle: JoinHandle<Option<T>>,
    /// Set by the task as soon as it begins executing on a blocking thread.
    started: Arc<AtomicBool>,
    /// Stop flag shared with the task. Only a weak reference is kept here; the task owns the
    /// strong one, so the flag is released together with the task's closure.
    stopped: Weak<AtomicBool>,
    /// Optional callback receiving the payload of a panic that happened inside the task.
    panic_handler: Option<Box<dyn Fn(PanicPayload) + Sync + Send>>,
}

impl<T> StoppableTaskHandle<T> {
    /// Whether the task has begun executing on a blocking thread.
    ///
    /// A started task may still finish without calling its function, if it observed the stop
    /// flag right after starting; in that case its join handle resolves to `None`.
    pub fn is_started(&self) -> bool {
        self.started.load(Ordering::SeqCst)
    }

    /// Whether the underlying blocking task has finished.
    pub fn is_finished(&self) -> bool {
        self.join_handle.is_finished()
    }

    /// Ask the task to stop by raising its stop flag. Does not wait for the task.
    ///
    /// If the task has already dropped the flag (it is done), this is a no-op.
    pub fn ask_to_stop(&self) {
        if let Some(stopped) = self.stopped.upgrade() {
            stopped.store(true, Ordering::SeqCst);
        }
    }

    /// Ask the task to stop and return its join handle if it has started.
    ///
    /// Returns `None` if the task had not started when asked to stop. Such a task will not call
    /// its function: once scheduled it sees the stop flag and exits. Its join handle is dropped
    /// here, which detaches it. Otherwise the returned handle should be awaited to wait for the
    /// task to wind down; it may resolve to `None` if the task bailed out right after starting.
    pub fn stop(self) -> Option<JoinHandle<Option<T>>> {
        // Raise the stop flag *before* reading `started`. The task does the mirror image: it
        // sets `started` before reading the stop flag, and all four accesses are `SeqCst`.
        // Together this guarantees at least one side observes the other's write. Either we see
        // `started == true` and hand out the join handle, or the task sees the stop flag and
        // never calls its function.
        self.ask_to_stop();
        self.is_started().then_some(self.join_handle)
    }

    /// Join this stoppable task and handle any panics.
    ///
    /// A panic inside the task is passed to the configured panic handler. If no handler is
    /// configured, the panic is only logged at debug level. Cancellation is ignored. Any other
    /// join error is logged at error level.
    ///
    /// The task must already be finished. This is checked with `debug_assert!`, so calling this
    /// on an unfinished task panics in debug builds. In release builds it awaits the task until
    /// it finishes.
    pub async fn join_and_handle_panic(self) {
        debug_assert!(
            self.join_handle.is_finished(),
            "Task must be finished, we cannot block here on awaiting the join handle",
        );

        match self.join_handle.await {
            Ok(_) => {}
            Err(err) if err.is_cancelled() => {}
            // Propagate panic
            Err(err) if err.is_panic() => match self.panic_handler {
                Some(panic_handler) => {
                    log::trace!("Handling stoppable task panic through custom panic handler");
                    let panic = err.into_panic();
                    panic_handler(panic);
                }
                None => {
                    log::debug!("Stoppable task panicked without panic handler");
                }
            },
            // Log error on unknown error
            Err(err) => {
                log::error!("Stoppable task handle error for unknown reason: {err}");
            }
        }
    }
}

/// Spawn stoppable task `f` on Tokio's blocking thread pool.
///
/// `f` receives the task's stop flag and is expected to poll it and return early once it is set
/// (see [`StoppableTaskHandle::ask_to_stop`]). If the flag is already set when the task begins
/// executing, `f` is not called and the task resolves to `None`.
///
/// An optional `panic_handler` may be given. It is called by
/// [`StoppableTaskHandle::join_and_handle_panic`] if the task panicked.
pub fn spawn_stoppable<F, T>(
    f: F,
    panic_handler: Option<Box<dyn Fn(PanicPayload) + Sync + Send>>,
) -> StoppableTaskHandle<T>
where
    F: FnOnce(&AtomicBool) -> T + Send + 'static,
    T: Send + 'static,
{
    let started = Arc::new(AtomicBool::new(false));
    let started_c = Arc::clone(&started);

    let stopped = Arc::new(AtomicBool::new(false));
    // The task's closure owns the only strong reference to the stop flag. The handle keeps a
    // weak one, so it is fine for the flag to be destroyed together with the task.
    let stopped_w = Arc::downgrade(&stopped);

    let join_handle = tokio::task::spawn_blocking(move || {
        // Publish `started` *before* checking the stop flag, both with `SeqCst`, mirroring
        // `StoppableTaskHandle::stop` (which raises the flag, then reads `started`). Checking
        // first and publishing afterwards would let `stop` observe `started == false` and drop
        // the join handle while this task had already passed the check and went on to run `f`.
        started.store(true, Ordering::SeqCst);
        if stopped.load(Ordering::SeqCst) {
            return None;
        }

        Some(f(&stopped))
    });

    StoppableTaskHandle {
        join_handle,
        started: started_c,
        stopped: stopped_w,
        panic_handler,
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::thread;
    use std::time::{Duration, Instant};

    use common::panic;
    use tokio::time::sleep;

    use super::*;

    const STEP: Duration = Duration::from_millis(5);

    /// Simple stoppable task counting steps until stopped. Panics after 1 minute.
    fn counting_task(stop: &AtomicBool) -> usize {
        let mut count = 0;
        let start = Instant::now();

        while !stop.load(Ordering::SeqCst) {
            count += 1;
            if start.elapsed() > Duration::from_secs(60) {
                panic!("Task is not stopped within 60 seconds");
            }
            thread::sleep(STEP);
        }

        count
    }

    /// A single running task stops when asked and reports how many steps it counted.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_stop() {
        let handle = spawn_stoppable(counting_task, None);

        // Signal task to stop after ~20 steps
        sleep(STEP * 20).await;
        assert!(!handle.is_finished());
        handle.ask_to_stop();

        sleep(Duration::from_secs(1)).await;
        assert!(handle.is_finished());

        // The task had been running for ~20 steps, so it must hand out its join handle and a
        // count. Expect fewer than 25 steps; we cannot be exact on busy systems.
        let join_handle = handle
            .stop()
            .expect("running task must return its join handle on stop");
        let count = join_handle
            .await
            .unwrap()
            .expect("running task must produce a count");
        assert!(
            count < 25,
            "Stoppable task count should be less than 25, but it is {count}",
        );
    }

    /// Many concurrently running tasks all stop when asked.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_stop_many() {
        const TASKS: usize = 64;

        let handles = (0..TASKS)
            .map(|_| spawn_stoppable(counting_task, None))
            .collect::<Vec<_>>();

        // Signal tasks to stop after ~20 steps
        sleep(STEP * 20).await;
        for handle in &handles {
            assert!(!handle.is_finished());
            handle.ask_to_stop();
        }

        // A task that did not get to run before being stopped yields no count, which is
        // tolerated here. Every count that is produced must stay below 30 (we cannot be exact
        // on busy systems), and at least one task must have produced a count.
        let mut counted = 0;
        for handle in handles {
            if let Some(join_handle) = handle.stop() {
                if let Some(count) = join_handle.await.unwrap() {
                    assert!(
                        count < 30, // 10 extra steps to stop all tasks
                        "Stoppable task count should be less than 30, but it is {count}",
                    );
                    counted += 1;
                }
            }
        }
        assert!(
            counted > 0,
            "At least one stoppable task should have produced a count",
        );
    }

    /// A panic inside the task is delivered to the panic handler on join.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_task_panic() {
        let panic_payload = Arc::new(Mutex::new(String::new()));

        let handle = spawn_stoppable(
            |_| {
                thread::sleep(STEP * 50);
                panic!("stoppable task panicked");
            },
            Some(Box::new({
                let panic_payload = panic_payload.clone();
                move |payload| {
                    *panic_payload.lock().unwrap() =
                        panic::downcast_str(&payload).unwrap_or("").into();
                }
            })),
        );

        sleep(STEP * 20).await;
        assert!(!handle.is_finished());

        sleep(STEP * 100).await;
        assert!(handle.is_finished());

        // Join handle to call back panic
        handle.join_and_handle_panic().await;
        assert_eq!(*panic_payload.lock().unwrap(), "stoppable task panicked");
    }
}
|