use std::fs::File; use std::io::BufWriter; use std::ops::{Deref, DerefMut}; use std::path::{Path, PathBuf}; use std::time::Duration; use atomicwrites::OverwriteBehavior::AllowOverwrite; use atomicwrites::{AtomicFile, Error as AtomicWriteError}; use common::tar_ext; use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard}; use serde::{Deserialize, Serialize}; /// Functions as a smart pointer which gives a write guard and saves data on disk /// when write guard is dropped. #[derive(Debug, Default)] pub struct SaveOnDisk { change_notification: Condvar, notification_lock: Mutex<()>, data: RwLock, path: PathBuf, } #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Failed to save structure on disk with error: {0}")] AtomicWrite(#[from] AtomicWriteError), #[error("Failed to perform io operation: {0}")] IoError(#[from] std::io::Error), #[error("Failed to (de)serialize from/to json: {0}")] JsonError(#[from] serde_json::Error), #[error("Error in write closure: {0}")] FromClosure(Box), } impl Deserialize<'de> + Clone> SaveOnDisk { /// Load data from disk at the given path, or initialize the default if it doesn't exist. pub fn load_or_init_default(path: impl Into) -> Result where T: Default, { Self::load_or_init(path, T::default) } /// Load data from disk at the given path, or initialize it with `init` if it doesn't exist. pub fn load_or_init(path: impl Into, init: impl FnOnce() -> T) -> Result { let path: PathBuf = path.into(); let data = if path.exists() { let file = File::open(&path)?; serde_json::from_reader(&file)? } else { init() }; Ok(Self { change_notification: Condvar::new(), notification_lock: Default::default(), data: RwLock::new(data), path, }) } /// Initialize new data, even if it already exists on disk at the given path. /// /// If data already exists on disk, it will be immediately overwritten. pub fn new(path: impl Into, data: T) -> Result { let data = Self { change_notification: Condvar::new(), notification_lock: Default::default(), data: RwLock::new(data), path: path.into(), }; if data.path.exists() { data.save()?; } Ok(data) } /// Wait for a condition on data to be true. /// /// Returns `true` if condition is true, `false` if timed out. #[must_use] pub fn wait_for(&self, check: F, timeout: Duration) -> bool where F: Fn(&T) -> bool, { let start = std::time::Instant::now(); while start.elapsed() < timeout { let mut data_read_guard = self.data.read(); if check(&data_read_guard) { return true; } let notification_guard = self.notification_lock.lock(); // Based on https://github.com/Amanieu/parking_lot/issues/165 RwLockReadGuard::unlocked(&mut data_read_guard, || { // Move the guard in so it gets unlocked before we re-lock g let mut guard = notification_guard; self.change_notification.wait_for(&mut guard, timeout); }); } false } /// Perform an operation over the stored data, /// persisting the result to disk if the operation returns `Some`. /// /// If the operation returns `None`, assumes that data has not changed pub fn write_optional(&self, f: impl FnOnce(&T) -> Option) -> Result { let read_data = self.data.upgradable_read(); let output_opt = f(&read_data); if let Some(output) = output_opt { Self::save_data_to(&self.path, &output)?; let mut write_data = RwLockUpgradableReadGuard::upgrade(read_data); *write_data = output; self.change_notification.notify_all(); Ok(true) } else { Ok(false) } } pub fn write(&self, f: impl FnOnce(&mut T) -> O) -> Result { let read_data = self.data.upgradable_read(); let mut data_copy = (*read_data).clone(); let output = f(&mut data_copy); Self::save_data_to(&self.path, &data_copy)?; let mut write_data = RwLockUpgradableReadGuard::upgrade(read_data); *write_data = data_copy; self.change_notification.notify_all(); Ok(output) } fn save_data_to(path: impl Into, data: &T) -> Result<(), Error> { let path: PathBuf = path.into(); AtomicFile::new(path, AllowOverwrite).write(|file| { let writer = BufWriter::new(file); serde_json::to_writer(writer, data) })?; Ok(()) } pub fn save(&self) -> Result<(), Error> { self.save_to(&self.path) } pub fn save_to(&self, path: impl Into) -> Result<(), Error> { Self::save_data_to(path, &self.data.read()) } pub async fn save_to_tar( &self, tar: &tar_ext::BuilderExt, path: impl AsRef, ) -> Result<(), Error> { let data_bytes = serde_json::to_vec(self.data.read().deref())?; tar.append_data(data_bytes, path.as_ref()).await?; Ok(()) } pub async fn delete(self) -> std::io::Result<()> { tokio::fs::remove_file(self.path).await } } impl Deref for SaveOnDisk { type Target = RwLock; fn deref(&self) -> &Self::Target { &self.data } } impl DerefMut for SaveOnDisk { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.data } } #[cfg(test)] mod tests { use std::sync::Arc; use std::thread::sleep; use std::time::Duration; use std::{fs, thread}; use tempfile::Builder; use super::SaveOnDisk; #[test] fn saves_data() { let dir = Builder::new().prefix("test").tempdir().unwrap(); let counter_file = dir.path().join("counter"); let counter: SaveOnDisk = SaveOnDisk::load_or_init_default(&counter_file).unwrap(); counter.write(|counter| *counter += 1).unwrap(); assert_eq!(*counter.read(), 1); assert_eq!( counter.read().to_string(), fs::read_to_string(&counter_file).unwrap() ); counter.write(|counter| *counter += 1).unwrap(); assert_eq!(*counter.read(), 2); assert_eq!( counter.read().to_string(), fs::read_to_string(&counter_file).unwrap() ); } #[test] fn loads_data() { let dir = Builder::new().prefix("test").tempdir().unwrap(); let counter_file = dir.path().join("counter"); let counter: SaveOnDisk = SaveOnDisk::load_or_init_default(&counter_file).unwrap(); counter.write(|counter| *counter += 1).unwrap(); let counter: SaveOnDisk = SaveOnDisk::load_or_init_default(&counter_file).unwrap(); let value = *counter.read(); assert_eq!(value, 1) } #[test] fn test_wait_for_condition_change() { let dir = Builder::new().prefix("test").tempdir().unwrap(); let counter_file = dir.path().join("counter"); let counter: Arc> = Arc::new(SaveOnDisk::load_or_init_default(counter_file).unwrap()); let counter_copy = counter.clone(); let handle = thread::spawn(move || { sleep(Duration::from_millis(200)); counter_copy.write(|counter| *counter += 3).unwrap(); sleep(Duration::from_millis(200)); counter_copy.write(|counter| *counter += 7).unwrap(); sleep(Duration::from_millis(200)); }); assert!(counter.wait_for(|counter| *counter > 5, Duration::from_secs(2))); handle.join().unwrap(); } #[test] fn test_wait_for_condition_change_timeout() { let dir = Builder::new().prefix("test").tempdir().unwrap(); let counter_file = dir.path().join("counter"); let counter: Arc> = Arc::new(SaveOnDisk::load_or_init_default(counter_file).unwrap()); let counter_copy = counter.clone(); let handle = thread::spawn(move || { sleep(Duration::from_millis(200)); counter_copy.write(|counter| *counter += 3).unwrap(); sleep(Duration::from_millis(200)); counter_copy.write(|counter| *counter += 7).unwrap(); sleep(Duration::from_millis(200)); }); assert!(!counter.wait_for(|counter| *counter > 5, Duration::from_millis(300))); handle.join().unwrap(); } }