diff --git "a/lib/python3.11/site-packages/mlx/include/mlx/3rdparty/pocketfft.h" "b/lib/python3.11/site-packages/mlx/include/mlx/3rdparty/pocketfft.h" new file mode 100644--- /dev/null +++ "b/lib/python3.11/site-packages/mlx/include/mlx/3rdparty/pocketfft.h" @@ -0,0 +1,3581 @@ +/* +This file is part of pocketfft. + +Copyright (C) 2010-2022 Max-Planck-Society +Copyright (C) 2019-2020 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +Authors: Martin Reinecke, Peter Bell + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef POCKETFFT_HDRONLY_H +#define POCKETFFT_HDRONLY_H + +#ifndef __cplusplus +#error This file is C++ and requires a C++ compiler. +#endif + +#if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L) +#error This file requires at least C++11 support. +#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE!=0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +# include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::size_t; +using std::ptrdiff_t; + +// Always use std:: for functions +template T cos(T) = delete; +template T sin(T) = delete; +template T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. +#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +# if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +# undef POCKETFFT_NO_VECTORS +# endif +#elif __clang_major__ >= 5 +# undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template struct VLEN { static constexpr size_t val=1; }; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> struct VLEN { static constexpr size_t val=16; }; +template<> struct VLEN { static constexpr size_t val=8; }; +#elif (defined(__AVX__)) +template<> struct VLEN { static constexpr size_t val=8; }; +template<> struct VLEN { static constexpr size_t val=4; }; +#elif (defined(__SSE2__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__VSX__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// the __MINGW32__ part in the conditional below works around the problem that +// the standard C++ library on Windows does not provide aligned_alloc() even +// though the MinGW compiler and MSVC may advertise C++17 compliance. +#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) + { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size+align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast + ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; + } +inline void aligned_dealloc(void *ptr) + { if (ptr) free((reinterpret_cast(ptr))[-1]); } +#endif + +template class arr + { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *res = malloc(num*sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) + { free(ptr); } +#else + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *ptr = aligned_alloc(64, num*sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) + { aligned_dealloc(ptr); } +#endif + + public: + arr() : p(0), sz(0) {} + arr(size_t n) : p(ralloc(n)), sz(n) {} + arr(arr &&other) + : p(other.p), sz(other.sz) + { other.p=nullptr; other.sz=0; } + ~arr() { dealloc(p); } + + void resize(size_t n) + { + if (n==sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { return p[idx]; } + const T &operator[](size_t idx) const { return p[idx]; } + + T *data() { return p; } + const T *data() const { return p; } + + size_t size() const { return sz; } + }; + +template struct cmplx { + T r, i; + cmplx() {} + cmplx(T r_, T i_) : r(r_), i(i_) {} + void Set(T r_, T i_) { r=r_; i=i_; } + void Set(T r_) { r=r_; i=T(0); } + cmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator*= (T2 other) + { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } + template auto operator* (const T2 &other) const + -> cmplx + { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } + template auto operator* (const cmplx &other) const + -> cmplx + { return {r*other.r-i*other.i, r*other.i + i*other.r}; } + template auto special_mul (const cmplx &other) const + -> cmplx + { + using Tres = cmplx; + return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) + : Tres(r*other.r-i*other.i, r*other.i+i*other.r); + } +}; +template inline void PM(T &a, T &b, T c, T d) + { a=c+d; b=c-d; } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template inline void MPINPLACE(T &a, T &b) + { T t = a; a-=b; b=t+b; } +template cmplx conj(const cmplx &a) + { return {a.r, -a.i}; } +template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) + { + res = fwd ? cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) + : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); + } + +template void ROT90(cmplx &a) + { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; } + +// +// twiddle factor section +// +template class sincos_2pibyn + { + private: + using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) + { + x<<=3; + if (x<4*n) // first half + { + if (x<2*n) // first quadrant + { + if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); + } + else // second quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); + } + } + else + { + x=8*n-x; + if (x<2*n) // third quadrant + { + if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); + } + else // fourth quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L*pi/n); + size_t nval = (n+2)/2; + shift = 1; + while((size_t(1)< operator[](size_t idx) const + { + if (2*idx<=N) + { + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); + } + idx = N-idx; + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); + } + }; + +struct util // hack to avoid duplicate symbols + { + static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) + { + size_t res=1; + while ((n&1)==0) + { res=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { res=x; n/=x; } + if (n>1) res=n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess (size_t n) + { + constexpr double lfp=1.1; // penalty for non-hardcoded larger factors + size_t ni=n; + double result=0.; + while ((n&1)==0) + { result+=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { + result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors + n/=x; + } + if (n>1) result+=(n<=5) ? double(n) : lfp*double(n); + return result*double(ni); + } + + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) + { + if (n<=12) return n; + + size_t bestfac=2*n; + for (size_t f11=1; f11n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) + { + if (n<=6) return n; + + size_t bestfac=2*n; + for (size_t f5=1; f5n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + static size_t prod(const shape_t &shape) + { + size_t res=1; + for (auto sz: shape) + res*=sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace) + { + auto ndim = shape.size(); + if (ndim<1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && (stride_in!=stride_out)) + throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + const shape_t &axes) + { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim,0); + for (auto ax : axes) + { + if (ax>=ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + size_t axis) + { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? + std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif + }; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { return 0; } +constexpr inline size_t num_threads() { return 1; } + +template +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + +inline size_t &thread_id() + { + static thread_local size_t thread_id_=0; + return thread_id_; + } +inline size_t &num_threads() + { + static thread_local size_t num_threads_=1; + return num_threads_; + } +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch + { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n): num_left_(n) {} + + void count_down() + { + lock_t lock(mut_); + if (--num_left_) + return; + completed_.notify_all(); + } + + void wait() + { + lock_t lock(mut_); + completed_.wait(lock, [this]{ return is_ready(); }); + } + bool is_ready() { return num_left_ == 0; } + }; + +template class concurrent_queue + { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + + void push(T val) + { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool try_pop(T &val) + { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { return size_==0; } + }; + +// C++ allocator with support for over-aligned types +template struct aligned_allocator + { + using value_type = T; + template + aligned_allocator(const aligned_allocator&) {} + aligned_allocator() = default; + + T *allocate(size_t n) + { + void* mem = aligned_alloc(alignof(T), n*sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) + { aligned_dealloc(p); } + }; + +class thread_pool + { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker + { + std::thread thread; + std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) + { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) + { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) + { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) + { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) + { + if (!marked_busy && busy_flag.test_and_set()) + { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) + { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() + { + lock_t lock(mut_); + size_t nthreads=workers_.size(); + for (size_t i=0; ibusy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread([worker, this] + { + worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); + }); + } + catch (...) + { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() + { + shutdown_ = true; + for (auto &worker : workers_) + worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) + worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads): + workers_(nthreads) + { create_threads(); } + + thread_pool(): thread_pool(max_threads) {} + + ~thread_pool() { shutdown(); } + + void submit(std::function work) + { + lock_t lock(mut_); + if (shutdown_) + throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) + { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() + { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() + { + shutdown_ = false; + create_threads(); + } + }; + +inline thread_pool & get_pool() + { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + + return pool; + } + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) + { + if (nthreads == 0) + nthreads = max_threads; + + if (nthreads == 1) + { f(); return; } + + auto & pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i=0; i lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) + std::rethrow_exception(ex); + } + +#endif + +} + +// +// complex FFTPACK transforms +// + +template class cfftp + { + private: + struct fctdata + { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +template void pass2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx,0,k), t1, t2; \ + PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ + CH(idx,k,0)=t0+t1; +#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ + } +#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass3 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r=-0.5, + tw1i= (fwd ? -1: 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void pass4 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + else + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + for (size_t i=1; i(t4); + CH(i,k,0) = t2+t3; + special_mul(t1+t4,WA(0,i),CH(i,k,1)); + special_mul(t2-t3,WA(1,i),CH(i,k,2)); + special_mul(t1-t4,WA(2,i),CH(i,k,3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx,0,k), t1, t2, t3, t4; \ + PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ + PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ + CH(idx,k,0).r=t0.r+t1.r+t2.r; \ + CH(idx,k,0).i=t0.i+t1.i+t2.i; + +#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb,da,db; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass5 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L), + tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), + tw2r= T0(-0.8090169943749474241022934171828191L), + tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass7(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), + tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r= T0(-0.2225209339563144042889025644967948L), + tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r= T0(-0.9009688679024191262361023195074451L), + tw3i= (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+7*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void ROTX45(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + } +template void ROTX135(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + } + +template void pass8 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+8*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + else + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + for (size_t i=1; i(a7); + PMINPLACE(a1,a3); + ROTX90(a3); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + PM(a0,a4,CC(i,0,k),CC(i,4,k)); + PM(a2,a6,CC(i,2,k),CC(i,6,k)); + PMINPLACE(a0,a2); + CH(i,k,0) = a0+a1; + special_mul(a0-a1,WA(3,i),CH(i,k,4)); + special_mul(a2+a3,WA(1,i),CH(i,k,2)); + special_mul(a2-a3,WA(5,i),CH(i,k,6)); + ROTX90(a6); + PMINPLACE(a4,a6); + special_mul(a4+a5,WA(0,i),CH(i,k,1)); + special_mul(a4-a5,WA(4,i),CH(i,k,5)); + special_mul(a6+a7,WA(2,i),CH(i,k,3)); + special_mul(a6-a7,WA(6,i),CH(i,k,7)); + } + } + } + + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ + PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ + PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ + PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ + PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ + CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ + CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ + { \ + T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ + cb; \ + cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ + cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ + PM(out1,out2,ca,cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) +#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + { \ + T da,db; \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ + special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass11 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), + tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r= T0(0.4154150130018864255292741492296232L), + tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r= T0(-0.1423148382732851404437926686163697L), + tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r= T0(-0.6548607339452850640569250724662936L), + tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r= T0(-0.9594929736144973898903680570663277L), + tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+11*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void passg (size_t ido, size_t ip, + size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa, + const cmplx * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph = (ip+1)/2; + size_t idl1 = ido*l1; + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& + { return ch[a+idl1*b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i=1; i(csarr[i].r,fwd ? -csarr[i].i : csarr[i].i); + + for (size_t k=0; kip) iwal-=ip; + cmplx xwal=wal[iwal]; + iwal+=l; if (iwal>ip) iwal-=ip; + cmplx xwal2=wal[iwal]; + for (size_t ik=0; ikip) iwal-=ip; + cmplx xwal=wal[iwal]; + for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); + idij=(jc-1)*(ido-1)+i-1; + special_mul(x2,wa[idij],CX(i,k,jc)); + } + } + } + } + +template void pass_all(T c[], T0 fct) const + { + if (length==1) { c[0]*=fct; return; } + size_t l1=1; + arr ch(length); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if(ip==2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if(ip==3) + pass3 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==5) + pass5 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==7) + pass7 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==11) + pass11 (ido, l1, p1, p2, fact[k1].tw); + else + { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1,p2); + } + std::swap(p1,p2); + l1=l2; + } + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const + { fwd ? pass_all(c, fct) : pass_all(c, fct); } + + private: + POCKETFFT_NOINLINE void factorize() + { + size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } + while ((len&3)==0) + { add_factor(4); len>>=2; } + if ((len&1)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsize=0, l1=1; + for (size_t k=0; k11) + twsize+=ip; + l1*=ip; + } + return twsize; + } + + void comp_twiddle() + { + sincos_2pibyn twiddle(length); + size_t l1=1; + size_t memofs=0; + for (size_t k=0; k11) + { + fact[k].tws=mem.data()+memofs; + memofs+=ip; + for (size_t j=0; j class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length; + arr mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +/* (a+ib) = conj(c+id) * (e+if) */ +template inline void MULPM + (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const + { a=c*e+d*f; b=c*f-d*e; } + +template void radf2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+2*c)]; }; + + for (size_t k=0; k void radf3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+3*c)]; }; + + for (size_t k=0; k void radf4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+4*c)]; }; + + for (size_t k=0; k void radf5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+5*c)]; }; + + for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1] (size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + if (ido>1) + { + for (size_t j=1, jc=ip-1; j=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik void radb2(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1](size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + for (size_t k=0; kip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik void copy_and_norm(T *c, T *p1, T0 fct) const + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const + { + if (length==1) { c[0]*=fct; return; } + size_t nf=fact.size(); + arr ch(length); + T *p1=c, *p2=ch.data(); + + if (r2hc) + for(size_t k1=0, l1=length; k1>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + + void comp_twiddle() + { + sincos_2pibyn twid(length); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k5) // special factors required by *g functions + { + fact[k].tws=ptr; ptr+=2*ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) + { + fact[k].tws[i ] = twid[i/2*(length/ip)].r; + fact[k].tws[i+1] = twid[i/2*(length/ip)].i; + fact[k].tws[ic] = twid[i/2*(length/ip)].r; + fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; + } + } + l1*=ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template class fftblue + { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template void fft(cmplx c[], T0 fct) const + { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m=0; m(c[m],bk[m],akf[m]); + auto zero = akf[0]*T0(0); + for (size_t m=n; m(bkf[0]); + for (size_t m=1; m<(n2+1)/2; ++m) + { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); + } + if ((n2&1)==0) + akf[n2/2] = akf[n2/2].template special_mul(bkf[n2/2]); + + /* inverse FFT */ + plan.exec (akf.data(),1.,false); + + /* multiply by b_k */ + for (size_t m=0; m(bk[m])*fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), + bk(mem.data()), bkf(mem.data()+n) + { + /* initialize b_k */ + sincos_2pibyn tmp(2*n); + bk[0].Set(1, 0); + + size_t coeff=0; + for (size_t m=1; m=2*n) coeff-=2*n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ + arr> tbkf(n2); + T0 xn2 = T0(1)/T0(n2); + tbkf[0] = bk[0]*xn2; + for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const + { fwd ? fft(c,fct) : fft(c,fct); } + + template void exec_r(T c[], T0 fct, bool fwd) + { + arr> tmp(n); + if (fwd) + { + auto zero = T0(0)*c[0]; + for (size_t m=0; m(tmp.data(),fct); + c[0] = tmp[0].r; + std::copy_n (&tmp[1].r, n-1, &c[1]); + } + else + { + tmp[0].Set(c[0],c[0]*0); + std::copy_n (c+1, n-1, &tmp[1].r); + if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; + for (size_t m=1; 2*m(tmp.data(),fct); + for (size_t m=0; m class pocketfft_c + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new cfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } + + size_t length() const { return len; } + }; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template class pocketfft_r + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5*util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new rfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } + + size_t length() const { return len; } + }; + + +// +// sine/cosine transforms +// + +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dcst23 + { + private: + pocketfft_r fftplan; + std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + sincos_2pibyn tw(4*length); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + size_t NS2 = (N+1)/2; + if (type==2) + { + if (!cosine) + for (size_t k=1; k class T_dcst4 + { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + if ((N&1)==0) + { + sincos_2pibyn tw(16*N); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k y(N); + { + size_t i=0, m=n2; + for (; mexec(y.data(), fct, true); + { + auto SGN = [](size_t i) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + return (i&2) ? -sqrt2 : sqrt2; + }; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k> y(n2); + for(size_t i=0; iexec(y.data(), fct, true); + for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) + { +#if POCKETFFT_CACHE_SIZE==0 + return std::make_shared(length); +#else + constexpr size_t nmax=POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr + { + for (size_t i=0; ilength()==length)) + { + // no need to update if this is already the most recent entry + if (last_access[i]!=access_counter) + { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) + last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i=1; i class cndarr: public arr_info + { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) {} + const T &operator[](ptrdiff_t ofs) const + { return *reinterpret_cast(d+ofs); } + }; + +template class ndarr: public cndarr + { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) + {} + T &operator[](ptrdiff_t ofs) + { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } + }; + +template class multi_iter + { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; + size_t idim, rem; + + void advance_i() + { + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + if (i==idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if (++pos[i] < iarr.shape(i)) + return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), + str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), + idim(idim_), rem(iarr.size()/iarr.shape(idim)) + { + auto nshares = threading::num_threads(); + if (nshares==1) return; + if (nshares==0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare>=nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem/nshares; + size_t additional = rem%nshares; + size_t lo = myshare*nbase + ((myshare=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + +class rev_iter + { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), p(0), rp(0) + { + for (auto ax: axes) + rev_axis[ax]=1; + last_axis = axes.back(); + last_size = arr.shape(last_axis)/2 + 1; + shp = arr.shape(); + shp[last_axis] = last_size; + rem=1; + for (auto i: shp) + rem *= i; + } + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += arr.stride(i); + else + { + rp -= arr.stride(i); + if (rev_jump[i]) + { + rp += ptrdiff_t(arr.shape(i))*arr.stride(i); + rev_jump[i] = 0; + } + } + if (++pos[i] < shp[i]) + return; + pos[i] = 0; + p -= ptrdiff_t(shp[i])*arr.stride(i); + if (rev_axis[i]) + { + rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); + rev_jump[i] = 1; + } + else + rp -= ptrdiff_t(shp[i])*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + ptrdiff_t rev_ofs() const { return rp; } + size_t remaining() const { return rem; } + }; + +template struct VTYPE {}; +template using vtype_t = typename VTYPE::type; + +#ifndef POCKETFFT_NO_VECTORS +template<> struct VTYPE + { + using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); + }; +template<> struct VTYPE + { + using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); + }; +template<> struct VTYPE + { + using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); + }; +#endif + +template arr alloc_tmp(const shape_t &shape, + size_t axsize, size_t elemsize) + { + auto othersize = util::prod(shape)/axsize; + auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); + return arr(tmpsize*elemsize); + } +template arr alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) + { + size_t fullsize=util::prod(shape); + size_t tmpsize=0; + for (size_t i=0; i=VLEN::val) ? VLEN::val : 1); + if (sz>tmpsize) tmpsize=sz; + } + return arr(tmpsize*elemsize); + } + +template void copy_input(const multi_iter &it, + const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, T *POCKETFFT_RESTRICT dst) + { + if (dst == &src[it.iofs(0)]) return; // in-place + for (size_t i=0; i void copy_output(const multi_iter &it, + const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + if (src == &dst[it.oofs(0)]) return; // in-place + for (size_t i=0; i struct add_vec { using type = vtype_t; }; +template struct add_vec> + { using type = cmplx>; }; +template using add_vec_t = typename add_vec::type; + +template +POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, + const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, + const bool allow_inplace=true) + { + std::shared_ptr plan; + + for (size_t iax=0; iaxlength())) + plan = get_plan(len); + + threading::thread_map( + util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + exec(it, tin, out, tdatav, *plan, fct); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto buf = allow_inplace && it.stride_out() == sizeof(T) ? + &out[it.oofs(0)] : reinterpret_cast(storage.data()); + exec(it, tin, out, buf, *plan, fct); + } + }); // end of parallel region + fct = T0(1); // factor has been applied, use 1 for remaining axes + } + } + +struct ExecC2C + { + bool forward; + + template void operator () ( + const multi_iter &it, const cndarr> &in, + ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, forward); + copy_output(it, buf, out); + } + }; + +template void copy_hartley(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t j=0; j void copy_hartley(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + dst[it.oofs(0)] = src[0]; + size_t i=1, i1=1, i2=it.length_out()-1; + for (i=1; i void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, + T * buf, const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, true); + copy_hartley(it, buf, out); + } + }; + +struct ExecDcst + { + bool ortho; + int type; + bool cosine; + + template + void operator () (const multi_iter &it, const cndarr &in, + ndarr &out, T * buf, const Tplan &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, ortho, type, cosine); + copy_output(it, buf, out); + } + }; + +template POCKETFFT_NOINLINE void general_r2c( + const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(in.shape(axis)); + size_t len=in.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + copy_input(it, in, tdatav); + plan->exec(tdatav, fct, true); + for (size_t j=0; j0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + copy_input(it, in, tdata); + plan->exec(tdata, fct, true); + out[it.oofs(0)].Set(tdata[0]); + size_t i=1, ii=1; + if (forward) + for (; i POCKETFFT_NOINLINE void general_c2r( + const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(out.shape(axis)); + size_t len=out.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(out.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + for (size_t j=0; jexec(tdatav, fct, false); + copy_output(it, tdatav, out); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + tdata[0]=in[it.iofs(0)].r; + { + size_t i=1, ii=1; + if (forward) + for (; iexec(tdata, fct, false); + copy_output(it, tdata, out); + } + }); // end of parallel region + } + +struct ExecR2R + { + bool r2h, forward; + + template void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, T * buf, + const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + if ((!r2h) && forward) + for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr> ain(data_in, shape, stride_in); + ndarr> aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); + } + +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, true}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, false}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axis); + cndarr ain(data_in, shape_in, stride_in); + shape_t shape_out(shape_in); + shape_out[axis] = shape_in[axis]/2 + 1; + ndarr> aout(data_out, shape_out, stride_out); + general_r2c(ain, aout, axis, forward, fct, nthreads); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axes); + r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, + fct, nthreads); + if (axes.size()==1) return; + + shape_t shape_out(shape_in); + shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, + T(1), nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + util::sanity_check(shape_out, stride_in, stride_out, false, axis); + shape_t shape_in(shape_out); + shape_in[axis] = shape_out[axis]/2 + 1; + cndarr> ain(data_in, shape_in, stride_in); + ndarr aout(data_out, shape_out, stride_out); + general_c2r(ain, aout, axis, forward, fct, nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + if (axes.size()==1) + return c2r(shape_out, stride_in, stride_out, axes[0], forward, + data_in, data_out, fct, nthreads); + util::sanity_check(shape_out, stride_in, stride_out, false, axes); + auto shape_in = shape_out; + shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; + auto nval = util::prod(shape_in); + stride_t stride_inter(shape_in.size()); + stride_inter.back() = sizeof(cmplx); + for (int i=int(shape_in.size())-2; i>=0; --i) + stride_inter[size_t(i)] = + stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); + arr> tmp(nval); + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), + T(1), nthreads); + c2r(shape_out, stride_inter, stride_out, axes.back(), forward, + tmp.data(), data_out, fct, nthreads); + } + +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, + ExecR2R{real2hermitian, forward}); + } + +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, + false); + } + +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + shape_t tshp(shape); + tshp[axes.back()] = tshp[axes.back()]/2+1; + arr> tdata(util::prod(tshp)); + stride_t tstride(shape.size()); + tstride.back()=sizeof(std::complex); + for (size_t i=tstride.size()-1; i>0; --i) + tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); + r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); + cndarr> atmp(tdata.data(), tshp, tstride); + ndarr aout(data_out, shape, stride_out); + simple_iter iin(atmp); + rev_iter iout(aout, axes); + while(iin.remaining()>0) + { + auto v = atmp[iin.ofs()]; + aout[iout.ofs()] = v.r+v.i; + aout[iout.rev_ofs()] = v.r-v.i; + iin.advance(); iout.advance(); + } + } + +} // namespace detail + +using detail::FORWARD; +using detail::BACKWARD; +using detail::shape_t; +using detail::stride_t; +using detail::c2c; +using detail::c2r; +using detail::r2c; +using detail::r2r_fftpack; +using detail::r2r_separable_hartley; +using detail::r2r_genuine_hartley; +using detail::dct; +using detail::dst; + +} // namespace pocketfft + +#undef POCKETFFT_NOINLINE +#undef POCKETFFT_RESTRICT + +#endif // POCKETFFT_HDRONLY_H