#ifndef DLIB_FFT_SIZE_H #define DLIB_FFT_SIZE_H #include <array> #include <algorithm> #include <numeric> #include "../assert.h" #include "../hash.h" namespace dlib { class fft_size { /*! WHAT THIS OBJECT REPRESENTS This object is a container used to store the dimensions of an FFT operation. It is implemented as a stack-based container with an upper bound of 5 dimensions (batch,channels,height,width,depth). All dimensions must be strictly positive. The object is either default constructed, constructed with an initialiser list or with a pair of iterators If default-constructed, the object is empty and in an invalid state. That is, FFT functions will throw if attempted to be used with such an object. If constructed with an initialiser list L, the object is properly initialised provided: - L.size() > 0 and L.size() <= 5 - L contains strictly positive values If constructed with a pair of iterators, the behaviour of the constructor is exactly the same as if constructed with an initializer list spanned by those iterators. Once the object is constructed, it is immutable. !*/ public: using container_type = std::array<long,5>; using const_reference = container_type::const_reference; using iterator = container_type::iterator; using const_iterator = container_type::const_iterator; fft_size() = default; /*! ensures - *this is properly initialised - num_dims() == 0 !*/ template<typename ConstIterator> fft_size(ConstIterator dims_begin, ConstIterator dims_end) /*! requires - ConstIterator is const iterator type that points to a long object - std::distance(dims_begin, dims_end) > 0 - std::distance(dims_begin, dims_end) <= 5 - range contains strictly positive values ensures - *this is properly initialised - num_dims() == std::distance(dims_begin, dims_end) - num_elements() == product of all values in range !*/ { const size_t ndims = std::distance(dims_begin, dims_end); DLIB_ASSERT(ndims > 0, "fft_size objects must be non-empty"); DLIB_ASSERT(ndims <= _dims.size(), "fft_size objects must have size less than 6"); DLIB_ASSERT(std::find_if(dims_begin, dims_end, [](long dim) {return dim <= 0;}) == dims_end, "fft_size objects must contain strictly positive values"); std::copy(dims_begin, dims_end, _dims.begin()); _size = ndims; _num_elements = std::accumulate(dims_begin, dims_end, 1, std::multiplies<long>()); } fft_size(std::initializer_list<long> dims) : fft_size(dims.begin(), dims.end()) /*! requires - dims.size() > 0 and dims.size() <= 5 - dims contains strictly positive values ensures - *this is properly initialised - num_dims() == dims.size() - num_elements() == product of all values in dims !*/ { } size_t num_dims() const /*! ensures - returns the number of dimensions !*/ { return _size; } long num_elements() const /*! ensures - if num_dims() > 0, returns the product of all dimensions, i.e. the total number of elements - if num_dims() == 0, returns 0 !*/ { return _num_elements; } const_reference operator[](size_t index) const /*! requires - index < num_dims() ensures - returns a const reference to the dimension at position index !*/ { DLIB_ASSERT(index < _size, "index " << index << " out of range [0," << _size << ")"); return _dims[index]; } const_reference back() const /*! requires - num_dims() > 0 ensures - returns a const reference to (*this)[num_dims()-1] !*/ { DLIB_ASSERT(_size > 0, "object is empty"); return _dims[_size-1]; } const_iterator begin() const /*! ensures - returns a const iterator that points to the first dimension in this container or end() if the array is empty. !*/ { return _dims.begin(); } const_iterator end() const /*! ensures - returns a const iterator that points to one past the end of the container. !*/ { return _dims.begin() + _size; } bool operator==(const fft_size& other) const /*! ensures - returns true if two fft_size objects have same size and same dimensions, i.e. if they have identical states !*/ { return this->_size == other._size && std::equal(begin(), end(), other.begin()); } private: size_t _size = 0; size_t _num_elements = 0; container_type _dims; }; inline dlib::uint32 hash( const fft_size& item, dlib::uint32 seed = 0) { seed = dlib::hash((dlib::uint64)item.num_dims(), seed); seed = std::accumulate(item.begin(), item.end(), seed, [](dlib::uint32 seed, long next) { return dlib::hash((dlib::uint64)next, seed); }); return seed; } /*! ensures - returns a 32bit hash of the data stored in item. !*/ inline fft_size pop_back(const fft_size& size) { DLIB_ASSERT(size.num_dims() > 0); return fft_size(size.begin(), size.end() - 1); } /*! requires - num_dims.size() > 0 ensures - returns a copy of size with the last dimension removed. !*/ inline fft_size squeeze_ones(const fft_size size) { DLIB_ASSERT(size.num_dims() > 0); fft_size newsize; if (size.num_elements() == 1) { newsize = {1}; } else { fft_size::container_type tmp; auto end = std::copy_if(size.begin(), size.end(), tmp.begin(), [](long dim){return dim != 1;}); newsize = fft_size(tmp.begin(), end); } return newsize; } /*! requires - num_dims.size() > 0 ensures - removes dimensions with values equal to 1, yielding a new fft_size object with the same num_elements() but fewer dimensions !*/ } #endif //DLIB_FFT_SIZE_H