|
#include <c10/macros/Macros.h> |
|
#include <c10/util/Synchronized.h> |
|
#include <array> |
|
#include <atomic> |
|
#include <mutex> |
|
#include <thread> |
|
|
|
namespace c10 { |
|
|
|
namespace detail { |
|
|
|
// Scope guard over an atomic counter: the constructor increments the counter
// and the destructor decrements it again. LeftRight<T>::read() uses it to keep
// a reader registered on a reader counter for exactly the duration of a read.
struct IncrementRAII final {
 public:
  // @param counter  counter to increment now and decrement when this guard is
  //                 destroyed. Must be non-null and must stay valid for the
  //                 guard's whole lifetime (the destructor dereferences it).
  explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
    _counter->fetch_add(1);
  }

  ~IncrementRAII() {
    _counter->fetch_sub(1);
  }

  // Non-copyable: a copy would decrement the counter a second time. The
  // deleted members are public so misuse is diagnosed as "use of deleted
  // function" instead of an access-control error. Declaring the copy
  // operations also suppresses the implicit move operations.
  IncrementRAII(const IncrementRAII&) = delete;
  IncrementRAII& operator=(const IncrementRAII&) = delete;

 private:
  std::atomic<int32_t>* _counter;
};
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// LeftRight<T> keeps two copies of a T so that readers can access one copy
// while a writer modifies the other. It follows the Left-Right technique:
//
//  - _data holds the two copies. _foregroundDataIndex selects the copy that
//    read() hands to readers; the other one is the "background" copy.
//  - _counters are two reader counters. _foregroundCounterIndex selects the
//    counter that new readers register on (via detail::IncrementRAII) for the
//    duration of their read.
//  - write() serializes writers on _writeMutex. It applies writeFunc to the
//    background copy and makes that copy the foreground. It then waits until
//    no reader can still be reading the old copy, and finally applies
//    writeFunc to the old copy too, so both copies hold the same state again.
//
// read() never locks _writeMutex and never waits. write() may spin (yielding
// via std::this_thread::yield()) until in-flight readers have finished.
//
// writeFunc is applied once to each copy, so it must make the same change to
// both of them; otherwise the two copies diverge. It is invoked through a
// const reference, so a `mutable` lambda (non-const call operator) does not
// compile.
//
// All atomic operations below use the default std::memory_order_seq_cst; the
// ordering argument in _write() assumes that single total order.
template <class T>
class LeftRight final {
 public:
  // Constructs both copies as T{args...}, i.e. with braced initialization, so
  // a std::initializer_list constructor of T (if any) is preferred. Both
  // reader counters start at 0 and copy 0 / counter 0 start as foreground.
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving are not supported: they would have to coordinate with
  // readers and writers that may be running concurrently on the source object.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  // Waits for operations that are already in progress before the copies are
  // destroyed.
  // NOTE(review): this only drains calls that have already started; callers
  // must still guarantee that no new read()/write() begins once destruction
  // has started.
  ~LeftRight() {
    // Acquiring (and immediately releasing) the write mutex waits for a
    // write() that currently holds it.
    { std::unique_lock<std::mutex> lock(_writeMutex); }

    // Spin until no reader is registered on either counter.
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Calls readFunc(const T&) on the current foreground copy and returns its
  // result. The return type is `auto`, so a reference returned by readFunc is
  // decayed into a copy.
  //
  // The reader registers on the foreground counter before it loads
  // _foregroundDataIndex, and stays registered until readFunc returns. This
  // is what allows write() to know when a copy is no longer being read.
  // readFunc must therefore not let references to the T escape: once read()
  // returns, a writer may modify that copy.
  template <typename F>
  auto read(F&& readFunc) const {
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
  }

  // Applies writeFunc(T&) to both copies (see _write) while holding
  // _writeMutex, so at most one write() runs at a time. Concurrent read()
  // calls proceed without blocking. Returns the result of the second call to
  // writeFunc.
  //
  // If writeFunc throws, the exception propagates. The object is then left in
  // the old state if the first call threw, or in the new state if the second
  // call threw (see _callWriteFuncOnBackgroundInstance).
  template <typename F>
  auto write(F&& writeFunc) {
    std::unique_lock<std::mutex> lock(_writeMutex);

    return _write(std::forward<F>(writeFunc));
  }

 private:
  // Body of write(); must be called with _writeMutex held.
  template <class F>
  auto _write(const F& writeFunc) {
    // Step 1: modify the background copy. No reader is accessing it: the
    // previous write() did not return until every reader that could still be
    // reading this copy had deregistered, and new readers are only directed
    // to the foreground copy.
    auto localDataIndex = _foregroundDataIndex.load();

    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // Step 2: publish the freshly written copy. Readers that load
    // _foregroundDataIndex from now on read the new state.
    localDataIndex = localDataIndex ^ 1;

    _foregroundDataIndex = localDataIndex;

    // Step 3: wait for readers still registered on the background counter.
    // Such a reader may have loaded the counter index before an earlier
    // toggle and may be reading the old copy. Once this counter has been
    // observed as zero, any reader that registers on it afterwards will load
    // the new _foregroundDataIndex.
    auto localCounterIndex = _foregroundCounterIndex.load();

    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // Step 4: direct new readers to the other (now drained) counter.
    localCounterIndex = localCounterIndex ^ 1;

    _foregroundCounterIndex = localCounterIndex;

    // Step 5: wait until every reader registered on the previous foreground
    // counter has finished. Those readers may still be reading the old copy.
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // Step 6: no reader can be reading the old copy any more, so bring it up
    // to date as well. localDataIndex now names the new foreground copy, so
    // this writes to the old one.
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Applies writeFunc to the background copy, i.e. the copy NOT selected by
  // localDataIndex, and returns its result.
  //
  // If writeFunc throws, the background copy is overwritten with a copy of
  // _data[localDataIndex] so that the two copies are equal again, and the
  // exception is rethrown.
  // NOTE(review): if that copy assignment itself throws, its exception
  // replaces the original one and the background copy may be left partially
  // modified.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // Restore the invariant that both copies hold the same state.
      _data[localDataIndex ^ 1] = _data[localDataIndex];

      throw;
    }
  }

  // Spins, yielding the thread, until the counter NOT selected by
  // counterIndex drops to zero.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Number of readers currently registered on each counter. It is mutable
  // because the const read() has to register and deregister readers.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  // Index (0 or 1) of the counter that new readers register on.
  std::atomic<uint8_t> _foregroundCounterIndex;
  // Index (0 or 1) of the copy that read() hands to readers.
  std::atomic<uint8_t> _foregroundDataIndex;
  // The two copies of the protected value.
  std::array<T, 2> _data;
  // Serializes writers; read() never takes it.
  std::mutex _writeMutex;
};
|
|
|
|
|
|
|
template <class T> |
|
class RWSafeLeftRightWrapper final { |
|
public: |
|
template <class... Args> |
|
explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} |
|
|
|
|
|
|
|
RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; |
|
RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; |
|
RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; |
|
RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; |
|
|
|
template <typename F> |
|
|
|
auto read(F&& readFunc) const { |
|
return data_.withLock( |
|
[&readFunc](T const& data) { return std::forward<F>(readFunc)(data); }); |
|
} |
|
|
|
template <typename F> |
|
|
|
auto write(F&& writeFunc) { |
|
return data_.withLock( |
|
[&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); }); |
|
} |
|
|
|
private: |
|
c10::Synchronized<T> data_; |
|
}; |
|
|
|
} |
|
|