|
#pragma once |
|
|
|
#include <cstddef>
#include <cstdint>

#if defined(_MSC_VER)
#include <intrin.h>
#endif
|
|
|
namespace c10::utils { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * A small fixed-width set of bits stored in a single integer of NUM_BITS()
 * bits.
 *
 * Bit positions are zero-based. No bounds checking is performed: every
 * `index` passed to set(), unset() or get() must be smaller than NUM_BITS().
 */
struct bitset final {
 private:
#if defined(_MSC_VER)
  // Storage type on MSVC; find_first_set() hands it to the _BitScanForward*
  // intrinsics.
  using bitset_type = int64_t;
#else
  // Storage type elsewhere; this is the argument type of __builtin_ffsll,
  // which find_first_set() uses.
  using bitset_type = long long int;
#endif

  /// Returns a value with only bit `index` set. The shift is done on an
  /// unsigned value so that selecting the top bit never shifts a 1 into the
  /// sign bit of a signed operand.
  static constexpr bitset_type single_bit(size_t index) noexcept {
    return static_cast<bitset_type>(
        static_cast<unsigned long long>(1) << index);
  }

 public:
  /// Number of bits this bitset can hold.
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  /// Creates a bitset with every bit cleared (see the initializer of bitset_).
  constexpr bitset() noexcept = default;
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;

  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;

  /// Turns bit `index` on.
  constexpr void set(size_t index) noexcept {
    bitset_ |= single_bit(index);
  }

  /// Turns bit `index` off.
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~single_bit(index);
  }

  /// Returns true iff bit `index` is on.
  constexpr bool get(size_t index) const noexcept {
    return 0 != (bitset_ & single_bit(index));
  }

  /// Returns true iff no bit is on.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  /**
   * Invokes `func(index)` once for every bit that is on, from the lowest
   * index to the highest.
   *
   * Iteration runs over a private copy of this bitset, so the bits visited
   * are those that were on when the call started.
   */
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    bitset remaining = *this;
    // find_first_set() is one-based and yields 0 once nothing is left.
    for (size_t position = remaining.find_first_set(); 0 != position;
         position = remaining.find_first_set()) {
      size_t index = position - 1;
      func(index);
      // Clear the bit just visited so the next scan finds the following one.
      remaining.unset(index);
    }
  }

 private:
  /**
   * Returns the one-based position of the lowest bit that is on (bit 0 gives
   * 1, bit 63 gives 64), or 0 when no bit is on.
   */
  size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    // _BitScanForward64 returns 0 when the mask has no bit set.
    if (0 == _BitScanForward64(&result, bitset_)) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
    // This target scans with the 32-bit intrinsic: low word first, then the
    // high word. _BitScanForward returns 0 when the scanned word is zero.
    unsigned long result;
    if (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_))) {
      return result + 1;
    }
    if (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_ >> 32))) {
      // Bits of the high word start at zero-based position 32.
      return result + 33;
    }
    // Both words are empty. This must be 0 (not 32): callers treat any
    // non-zero value as "bit (value - 1) is on".
    return 0;
#else
    return __builtin_ffsll(bitset_);
#endif
  }

  /// Two bitsets are equal iff exactly the same bits are on.
  friend bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

  // All bits start cleared.
  bitset_type bitset_{0};
};
|
|
|
/// Inequality for bitsets: the exact negation of operator==.
inline bool operator!=(bitset lhs, bitset rhs) noexcept {
  const bool same_bits = (lhs == rhs);
  return !same_bits;
}
|
|
|
} |
|
|