Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
/** C++14 <functional> extensions */ | |
namespace cute { | |
/**************/ | |
/** Identity **/ | |
/**************/ | |
struct identity { | |
template <class T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) operator()(T&& arg) const { | |
return static_cast<T&&>(arg); | |
} | |
}; | |
template <class R> | |
struct constant_fn { | |
template <class... T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) operator()(T&&...) const { | |
return r_; | |
} | |
R r_; | |
}; | |
/***********/ | |
/** Unary **/ | |
/***********/ | |
struct NAME { \ | |
template <class T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
decltype(auto) operator()(T&& arg) const { \ | |
return OP static_cast<T&&>(arg); \ | |
} \ | |
} | |
struct NAME { \ | |
template <class T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
decltype(auto) operator()(T&& arg) const { \ | |
return static_cast<T&&>(arg) OP ; \ | |
} \ | |
} | |
struct NAME { \ | |
template <class T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
decltype(auto) operator()(T&& arg) const { \ | |
return OP (static_cast<T&&>(arg)); \ | |
} \ | |
} | |
CUTE_LEFT_UNARY_OP(unary_plus, +); | |
CUTE_LEFT_UNARY_OP(negate, -); | |
CUTE_LEFT_UNARY_OP(bit_not, ~); | |
CUTE_LEFT_UNARY_OP(logical_not, !); | |
CUTE_LEFT_UNARY_OP(dereference, *); | |
CUTE_LEFT_UNARY_OP(address_of, &); | |
CUTE_LEFT_UNARY_OP(pre_increment, ++); | |
CUTE_LEFT_UNARY_OP(pre_decrement, --); | |
CUTE_RIGHT_UNARY_OP(post_increment, ++); | |
CUTE_RIGHT_UNARY_OP(post_decrement, --); | |
CUTE_NAMED_UNARY_OP(abs_fn, abs); | |
CUTE_NAMED_UNARY_OP(conjugate, cute::conj); | |
template <int Shift_> | |
struct shift_right_const { | |
static constexpr int Shift = Shift_; | |
template <class T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) operator()(T&& arg) const { | |
return static_cast<T&&>(arg) >> Shift; | |
} | |
}; | |
template <int Shift_> | |
struct shift_left_const { | |
static constexpr int Shift = Shift_; | |
template <class T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) operator()(T&& arg) const { | |
return static_cast<T&&>(arg) << Shift; | |
} | |
}; | |
/************/ | |
/** Binary **/ | |
/************/ | |
struct NAME { \ | |
template <class T, class U> \ | |
CUTE_HOST_DEVICE constexpr \ | |
decltype(auto) operator()(T&& lhs, U&& rhs) const { \ | |
return static_cast<T&&>(lhs) OP static_cast<U&&>(rhs); \ | |
} \ | |
} | |
struct NAME { \ | |
template <class T, class U> \ | |
CUTE_HOST_DEVICE constexpr \ | |
decltype(auto) operator()(T&& lhs, U&& rhs) const { \ | |
return OP (static_cast<T&&>(lhs), static_cast<U&&>(rhs)); \ | |
} \ | |
} | |
CUTE_BINARY_OP(plus, +); | |
CUTE_BINARY_OP(minus, -); | |
CUTE_BINARY_OP(multiplies, *); | |
CUTE_BINARY_OP(divides, /); | |
CUTE_BINARY_OP(modulus, %); | |
CUTE_BINARY_OP(plus_assign, +=); | |
CUTE_BINARY_OP(minus_assign, -=); | |
CUTE_BINARY_OP(multiplies_assign, *=); | |
CUTE_BINARY_OP(divides_assign, /=); | |
CUTE_BINARY_OP(modulus_assign, %=); | |
CUTE_BINARY_OP(bit_and, &); | |
CUTE_BINARY_OP(bit_or, |); | |
CUTE_BINARY_OP(bit_xor, ^); | |
CUTE_BINARY_OP(left_shift, <<); | |
CUTE_BINARY_OP(right_shift, >>); | |
CUTE_BINARY_OP(bit_and_assign, &=); | |
CUTE_BINARY_OP(bit_or_assign, |=); | |
CUTE_BINARY_OP(bit_xor_assign, ^=); | |
CUTE_BINARY_OP(left_shift_assign, <<=); | |
CUTE_BINARY_OP(right_shift_assign, >>=); | |
CUTE_BINARY_OP(logical_and, &&); | |
CUTE_BINARY_OP(logical_or, ||); | |
CUTE_BINARY_OP(equal_to, ==); | |
CUTE_BINARY_OP(not_equal_to, !=); | |
CUTE_BINARY_OP(greater, >); | |
CUTE_BINARY_OP(less, <); | |
CUTE_BINARY_OP(greater_equal, >=); | |
CUTE_BINARY_OP(less_equal, <=); | |
CUTE_NAMED_BINARY_OP(max_fn, cute::max); | |
CUTE_NAMED_BINARY_OP(min_fn, cute::min); | |
/**********/ | |
/** Fold **/ | |
/**********/ | |
struct NAME##_unary_rfold { \ | |
template <class... T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
auto operator()(T&&... t) const { \ | |
return (t OP ...); \ | |
} \ | |
}; \ | |
struct NAME##_unary_lfold { \ | |
template <class... T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
auto operator()(T&&... t) const { \ | |
return (... OP t); \ | |
} \ | |
}; \ | |
struct NAME##_binary_rfold { \ | |
template <class U, class... T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
auto operator()(U&& u, T&&... t) const { \ | |
return (t OP ... OP u); \ | |
} \ | |
}; \ | |
struct NAME##_binary_lfold { \ | |
template <class U, class... T> \ | |
CUTE_HOST_DEVICE constexpr \ | |
auto operator()(U&& u, T&&... t) const { \ | |
return (u OP ... OP t); \ | |
} \ | |
} | |
CUTE_FOLD_OP(plus, +); | |
CUTE_FOLD_OP(minus, -); | |
CUTE_FOLD_OP(multiplies, *); | |
CUTE_FOLD_OP(divides, /); | |
CUTE_FOLD_OP(modulus, %); | |
CUTE_FOLD_OP(plus_assign, +=); | |
CUTE_FOLD_OP(minus_assign, -=); | |
CUTE_FOLD_OP(multiplies_assign, *=); | |
CUTE_FOLD_OP(divides_assign, /=); | |
CUTE_FOLD_OP(modulus_assign, %=); | |
CUTE_FOLD_OP(bit_and, &); | |
CUTE_FOLD_OP(bit_or, |); | |
CUTE_FOLD_OP(bit_xor, ^); | |
CUTE_FOLD_OP(left_shift, <<); | |
CUTE_FOLD_OP(right_shift, >>); | |
CUTE_FOLD_OP(bit_and_assign, &=); | |
CUTE_FOLD_OP(bit_or_assign, |=); | |
CUTE_FOLD_OP(bit_xor_assign, ^=); | |
CUTE_FOLD_OP(left_shift_assign, <<=); | |
CUTE_FOLD_OP(right_shift_assign, >>=); | |
CUTE_FOLD_OP(logical_and, &&); | |
CUTE_FOLD_OP(logical_or, ||); | |
CUTE_FOLD_OP(equal_to, ==); | |
CUTE_FOLD_OP(not_equal_to, !=); | |
CUTE_FOLD_OP(greater, >); | |
CUTE_FOLD_OP(less, <); | |
CUTE_FOLD_OP(greater_equal, >=); | |
CUTE_FOLD_OP(less_equal, <=); | |
/**********/ | |
/** Meta **/ | |
/**********/ | |
template <class Fn, class Arg> | |
struct bound_fn { | |
template <class T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) | |
operator()(T&& arg) { | |
return fn_(arg_, static_cast<T&&>(arg)); | |
} | |
Fn fn_; | |
Arg arg_; | |
}; | |
template <class Fn, class Arg> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
bind(Fn const& fn, Arg const& arg) { | |
return bound_fn<Fn,Arg>{fn, arg}; | |
} | |
} // end namespace cute | |