#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) #error \ "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead." #endif namespace c10_complex_math { // Exponential functions template C10_HOST_DEVICE inline c10::complex exp(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::exp( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::exp(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::log( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::log(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log10(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::log10( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::log10(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log2(const c10::complex& x) { const c10::complex log2 = c10::complex(::log(2.0), 0.0); return c10_complex_math::log(x) / log2; } // Power functions // #if defined(_LIBCPP_VERSION) || \ (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) namespace _detail { TORCH_API c10::complex sqrt(const c10::complex& in); TORCH_API c10::complex sqrt(const c10::complex& in); TORCH_API c10::complex acos(const c10::complex& in); TORCH_API c10::complex acos(const c10::complex& in); }; // namespace _detail #endif template C10_HOST_DEVICE inline c10::complex sqrt(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::sqrt( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #elif !( \ defined(_LIBCPP_VERSION) || \ (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) return static_cast>( std::sqrt(static_cast>(x))); #else return _detail::sqrt(x); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x), c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(y))); #else return static_cast>(std::pow( static_cast>(x), static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const T& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x), y)); #else return static_cast>( std::pow(static_cast>(x), y)); #endif } template C10_HOST_DEVICE inline c10::complex pow( const T& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( x, c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(y))); #else return static_cast>( std::pow(x, static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x), c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(y))); #else return static_cast>(std::pow( static_cast>(x), static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const U& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x), y)); #else return static_cast>( std::pow(static_cast>(x), y)); #endif } template C10_HOST_DEVICE inline c10::complex pow( const T& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( x, c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(y))); #else return static_cast>( std::pow(x, static_cast>(y))); #endif } // Trigonometric functions template C10_HOST_DEVICE inline c10::complex sin(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::sin( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::sin(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex cos(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::cos( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::cos(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex tan(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::tan( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::tan(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex asin(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::asin( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::asin(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex acos(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::acos( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #elif !defined(_LIBCPP_VERSION) return static_cast>( std::acos(static_cast>(x))); #else return _detail::acos(x); #endif } template C10_HOST_DEVICE inline c10::complex atan(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::atan( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::atan(static_cast>(x))); #endif } // Hyperbolic functions template C10_HOST_DEVICE inline c10::complex sinh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::sinh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::sinh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex cosh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::cosh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::cosh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex tanh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::tanh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::tanh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex asinh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::asinh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::asinh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex acosh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::acosh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::acosh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::atanh( c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x))); #else return static_cast>( std::atanh(static_cast>(x))); #endif } } // namespace c10_complex_math using c10_complex_math::acos; using c10_complex_math::acosh; using c10_complex_math::asin; using c10_complex_math::asinh; using c10_complex_math::atan; using c10_complex_math::atanh; using c10_complex_math::cos; using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; using c10_complex_math::sinh; using c10_complex_math::sqrt; using c10_complex_math::tan; using c10_complex_math::tanh; namespace std { using c10_complex_math::acos; using c10_complex_math::acosh; using c10_complex_math::asin; using c10_complex_math::asinh; using c10_complex_math::atan; using c10_complex_math::atanh; using c10_complex_math::cos; using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; using c10_complex_math::sinh; using c10_complex_math::sqrt; using c10_complex_math::tan; using c10_complex_math::tanh; } // namespace std