Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
/** IntTuple is an integer or a tuple of IntTuples. | |
* This file holds utilities for working with IntTuples, | |
* but does not hold a concrete concept or class of IntTuple. | |
*/ | |
namespace cute | |
{ | |
// Implementation of get<0>(Integral). | |
// Even though is_tuple<Integral> is false and tuple_size<Integral> doesn't compile, | |
// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input | |
template <size_t I, class T, __CUTE_REQUIRES(cute::is_integral<cute::remove_cvref_t<T>>::value)> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) | |
get(T&& t) noexcept | |
{ | |
static_assert(I == 0, "Index out of range"); | |
return static_cast<T&&>(t); | |
} | |
// Custom recursive get for anything that implements get<I>(.) (for a single integer I). | |
template <size_t I0, size_t I1, size_t... Is, class T> | |
CUTE_HOST_DEVICE constexpr | |
decltype(auto) | |
get(T&& t) noexcept | |
{ | |
return get<I1, Is...>(get<I0>(static_cast<T&&>(t))); | |
} | |
// | |
// rank | |
// | |
template <int... Is, class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
rank(IntTuple const& t) | |
{ | |
if constexpr (sizeof...(Is) == 0) { | |
if constexpr (is_tuple<IntTuple>::value) { | |
return Int<tuple_size<IntTuple>::value>{}; | |
} else { | |
return Int<1>{}; | |
} | |
} else { | |
return rank(get<Is...>(t)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class IntTuple> | |
using rank_t = decltype(rank(declval<IntTuple>())); | |
template <class IntTuple> | |
static constexpr int rank_v = rank_t<IntTuple>::value; | |
// | |
// shape | |
// | |
template <class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
shape(IntTuple const& s) | |
{ | |
if constexpr (is_tuple<IntTuple>::value) { | |
return transform(s, [](auto const& a) { return shape(a); }); | |
} else { | |
return s; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <int I, int... Is, class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
shape(IntTuple const& s) | |
{ | |
if constexpr (is_tuple<IntTuple>::value) { | |
return shape<Is...>(get<I>(s)); | |
} else { | |
return get<I,Is...>(shape(s)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// max | |
// | |
template <class T0, class... Ts> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
max(T0 const& t0, Ts const&... ts) | |
{ | |
if constexpr (is_tuple<T0>::value) { | |
return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); | |
} else if constexpr (sizeof...(Ts) == 0) { | |
return t0; | |
} else { | |
return cute::max(t0, cute::max(ts...)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// min | |
// | |
template <class T0, class... Ts> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
min(T0 const& t0, Ts const&... ts) | |
{ | |
if constexpr (is_tuple<T0>::value) { | |
return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); | |
} else if constexpr (sizeof...(Ts) == 0) { | |
return t0; | |
} else { | |
return cute::min(t0, cute::min(ts...)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// gcd | |
// | |
template <class T0, class... Ts> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
gcd(T0 const& t0, Ts const&... ts) | |
{ | |
if constexpr (is_tuple<T0>::value) { | |
return cute::gcd(cute::apply(t0, [](auto const&... a){ return cute::gcd(a...); }), ts...); | |
} else if constexpr (sizeof...(Ts) == 0) { | |
return t0; | |
} else { | |
return cute::gcd(t0, cute::gcd(ts...)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// depth | |
// | |
template <int... Is, class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
depth(IntTuple const& t) | |
{ | |
if constexpr (sizeof...(Is) == 0) { | |
if constexpr (is_tuple<IntTuple>::value) { | |
return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); | |
} else { | |
return Int<0>{}; | |
} | |
} else { | |
return depth(get<Is...>(t)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class Tuple> | |
using depth_t = decltype(depth(declval<Tuple>())); | |
template <class Tuple> | |
static constexpr int depth_v = depth_t<Tuple>::value; | |
// | |
// product | |
// | |
// Implementation of product as a function object | |
struct Product | |
{ | |
template <class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
operator()(IntTuple const& a) const | |
{ | |
if constexpr (is_tuple<IntTuple>::value) { | |
if constexpr (tuple_size<IntTuple>::value == 0) { | |
return Int<1>{}; | |
} else { | |
return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); | |
} | |
} else if constexpr (cute::is_integral<IntTuple>::value) { | |
return a; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
}; | |
// Callable product function object | |
CUTE_INLINE_CONSTANT Product product; | |
// Return a rank(t) tuple @a result such that get<i>(@a result) = product(get<i>(@a t)) | |
template <class Tuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
product_each(Tuple const& t) | |
{ | |
return transform(wrap(t), product); | |
} | |
// Take the product of Tuple at the leaves of TupleG | |
template <class Tuple, class TupleG> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
product_like(Tuple const& tuple, TupleG const& guide) | |
{ | |
return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); | |
} | |
// Return the product of elements in a mode | |
template <int... Is, class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
size(IntTuple const& a) | |
{ | |
if constexpr (sizeof...(Is) == 0) { | |
return product(a); | |
} else { | |
return size(get<Is...>(a)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class IntTuple> | |
static constexpr int size_v = decltype(size(declval<IntTuple>()))::value; | |
// | |
// sum | |
// | |
template <class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
sum(IntTuple const& a) | |
{ | |
if constexpr (is_tuple<IntTuple>::value) { | |
return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); | |
} else { | |
return a; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// inner_product | |
// | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
inner_product(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks"); | |
return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, | |
[](auto const&... v) { return (Int<0>{} + ... + v); }); | |
} else { | |
return a * b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// ceil_div | |
// | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
ceil_div(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value) { | |
if constexpr (is_tuple<IntTupleB>::value) { // tuple tuple | |
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks"); | |
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1 | |
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); | |
} else { // tuple int | |
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), | |
[] (auto const& init, auto const& ai) { | |
return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); | |
}); | |
return result; | |
} | |
} else | |
if constexpr (is_tuple<IntTupleB>::value) { // int tuple | |
return ceil_div(a, product(b)); | |
} else { | |
return (a + b - Int<1>{}) / b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// round_up | |
// Round @a a up to the nearest multiple of @a b. | |
// For negative numbers, rounds away from zero. | |
// | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
round_up(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks"); | |
constexpr int R = tuple_size<IntTupleA>::value; // Missing ranks in TupleB are implicitly 1 | |
return transform(a, append<R>(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); | |
} else { | |
return ((a + b - Int<1>{}) / b) * b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
/** Division for Shapes | |
* Case Tuple Tuple: | |
* Perform shape_div element-wise | |
* Case Tuple Int: | |
* Fold the division of b across each element of a | |
* Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) | |
* Case Int Tuple: | |
* Return shape_div(a, product(b)) | |
* Case Int Int: | |
* Enforce the divisibility condition a % b == 0 || b % a == 0 when possible | |
* Return a / b with rounding away from 0 (that is, 1 or -1 when a < b) | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
shape_div(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value) { | |
if constexpr (is_tuple<IntTupleB>::value) { // tuple tuple | |
static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks"); | |
return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); | |
} else { // tuple int | |
auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), | |
[] (auto const& init, auto const& ai) { | |
return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); | |
}); | |
return result; | |
} | |
} else | |
if constexpr (is_tuple<IntTupleB>::value) { // int tuple | |
return shape_div(a, product(b)); | |
} else | |
if constexpr (is_static<IntTupleA>::value && is_static<IntTupleB>::value) { | |
static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); | |
return C<shape_div(IntTupleA::value, IntTupleB::value)>{}; | |
} else { // int int | |
//assert(a % b == 0 || b % a == 0); // Waive dynamic assertion | |
return a / b != 0 ? a / b : signum(a) * signum(b); // Division with rounding away from zero | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
/** Minimum for Shapes | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
shape_min(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value || is_tuple<IntTupleB>::value) { | |
static_assert(dependent_false<IntTupleA>, "Not implemented."); | |
} else | |
if constexpr (is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value) { | |
return Int<1>{}; // _1 is less than all other shapes, preserve static | |
} else { | |
return cute::min(a, b); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
/** Return a tuple the same profile as A scaled by corresponding elements in B | |
*/ | |
template <class A, class B> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_scale(A const& a, B const& b) | |
{ | |
if constexpr (is_tuple<A>::value) { | |
return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); | |
} else { | |
return a * product(b); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
/** Test if two IntTuple have the same profile (hierarchical rank division) | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
congruent(IntTupleA const& a, IntTupleB const& b) | |
{ | |
return bool_constant<is_same<decltype(repeat_like(shape(a),_0{})), | |
decltype(repeat_like(shape(b),_0{}))>::value>{}; | |
} | |
template <class A, class B> | |
using is_congruent = decltype(congruent(declval<A>(), declval<B>())); | |
/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) | |
* weakly_congruent is a partial order on A and B: A <= B | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
weakly_congruent(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
if constexpr (tuple_size<IntTupleA>::value != tuple_size<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, | |
[](auto const&... z) { return (true_type{} && ... && z); }); | |
} | |
} else if constexpr (is_integral<IntTupleA>::value) { | |
return true_type{}; | |
} else if constexpr (is_integral<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return weakly_congruent(shape(a), shape(b)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class A, class B> | |
using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>())); | |
/** Test if Shape A is compatible with Shape B: | |
* the size of A and B are the same, and | |
* any coordinate into A can also be used as a coordinate into B | |
* compatible is a partial order on A and B: A <= B | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
compatible(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
if constexpr (tuple_size<IntTupleA>::value != tuple_size<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, | |
[](auto const&... z) { return (true_type{} && ... && z); }); | |
} | |
} else if constexpr (is_integral<IntTupleA>::value) { | |
return a == size(b); | |
} else if constexpr (is_integral<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return compatible(shape(a), shape(b)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class A, class B> | |
using is_compatible = decltype(compatible(declval<A>(), declval<B>())); | |
/** Test if Shape A is weakly compatible with Shape B: | |
* there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B) | |
* weakly_compatible is a partial order on A and B: A <= B | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
weakly_compatible(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
if constexpr (tuple_size<IntTupleA>::value != tuple_size<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_compatible(x,y); }, | |
[](auto const&... z) { return (true_type{} && ... && z); }); | |
} | |
} else if constexpr (is_integral<IntTupleA>::value) { | |
return size(b) % a == Int<0>{}; | |
} else if constexpr (is_integral<IntTupleB>::value) { | |
return false_type{}; | |
} else { | |
return weakly_compatible(shape(a), shape(b)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class A, class B> | |
using is_weakly_compatible = decltype(weakly_compatible(declval<A>(), declval<B>())); | |
/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> | |
*/ | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
filter_zeros(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value) { | |
return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); | |
} else if constexpr (is_constant<0, IntTupleA>::value) { | |
return Int<1>{}; | |
} else { | |
return b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class Tuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
filter_zeros(Tuple const& t) | |
{ | |
return filter_zeros(t, t); | |
} | |
// | |
// Converters and constructors with arrays and params | |
// | |
/** Make an IntTuple of rank N from an Indexable array. | |
* Access elements up to a dynamic index n, then use init (requires compatible types) | |
* Consider cute::take<B,E> if all indexing is known to be valid | |
* \code | |
* std::vector<int> a = {6,3,4}; | |
* auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) | |
* \endcode | |
*/ | |
template <int N, class Indexable, class T> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
make_int_tuple(Indexable const& t, int n, T const& init) | |
{ | |
static_assert(N > 0); | |
if constexpr (N == 1) { | |
return 0 < n ? t[0] : init; | |
} else { | |
return transform(make_seq<N>{}, [&](auto i) { return i < n ? t[i] : init; }); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
/** Fill the dynamic values of a Tuple with values from another Tuple | |
* \code | |
* auto params = make_tuple(6,3,4); | |
* cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>> result; | |
* fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) | |
* \endcode | |
*/ | |
template <class Tuple, class TupleV> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
fill_int_tuple_from(Tuple& result, TupleV const& vals) | |
{ | |
return fold(result, vals, [](auto const& init, auto&& r) { | |
if constexpr (is_static<remove_cvref_t<decltype(r)>>::value) { // Skip static elements of result | |
return init; | |
} else if constexpr (is_tuple<remove_cvref_t<decltype(r)>>::value) { // Recurse into tuples | |
return fill_int_tuple_from(r, init); | |
} else { // Assign and consume arg | |
static_assert(tuple_size<remove_cvref_t<decltype(init)>>::value > 0, "Not enough values to fill with!"); | |
r = get<0>(init); | |
return remove<0>(init); | |
} | |
CUTE_GCC_UNREACHABLE; | |
}); | |
} | |
/** Make a "Tuple" by filling in the dynamic values in order from the arguments | |
* \code | |
* using result_t = cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>>; | |
* auto result = make_int_tuple_from<result_t>(6,3,4); // (_1,(6,3,_3),4,_2) | |
* \endcode | |
*/ | |
template <class Tuple, class... Ts> | |
CUTE_HOST_DEVICE constexpr | |
Tuple | |
make_int_tuple_from(Ts const&... ts) | |
{ | |
Tuple result = Tuple{}; | |
fill_int_tuple_from(result, cute::make_tuple(ts...)); | |
return result; | |
} | |
/** Convert a tuple to a flat homogeneous array of type T | |
* \code | |
* auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); | |
* cute::array<uint64_t,6> result = to_array<uint64_t>(tup); // [1,6,3,3,4,2] | |
* \endcode | |
*/ | |
template <class T = int64_t, class IntTuple> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
to_array(IntTuple const& t) | |
{ | |
auto flat_t = flatten_to_tuple(t); | |
constexpr int N = tuple_size<decltype(flat_t)>::value; | |
cute::array<T,N> result; | |
for_each(make_seq<N>{}, [&] (auto i) { result[i] = get<i>(flat_t); }); | |
return result; | |
} | |
// | |
// Comparison operators | |
// | |
// | |
// There are many ways to compare tuple of elements and because CuTe is built | |
// on parameterizing layouts of coordinates, some comparisons are appropriate | |
// only in certain cases. | |
// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout | |
// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout | |
// -- element-wise comparison [any,all] : | |
// This can be very confusing. To avoid errors in selecting the appropriate | |
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. | |
// | |
// When actually desiring to order coordinates, the user should map them to | |
// their indices within the Layout they came from: | |
// e.g. layoutX(coordA) < layoutX(coordB) | |
// That said, we implement the three most common ways to compare tuples below. | |
// These are implemented with slighly more explicit names than op<. | |
// | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_less(IntTupleA const& a, IntTupleB const& b); | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_less(IntTupleA const& a, IntTupleB const& b); | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_less(IntTupleA const& a, IntTupleB const& b); | |
namespace detail { | |
template <size_t I, class TupleA, class TupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_less_impl(TupleA const& a, TupleB const& b) | |
{ | |
if constexpr (I == tuple_size<TupleB>::value) { | |
return cute::false_type{}; // Terminal: TupleB is exhausted | |
} else if constexpr (I == tuple_size<TupleA>::value) { | |
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted | |
} else { | |
return lex_less(get<I>(a), get<I>(b)) || (get<I>(a) == get<I>(b) && lex_less_impl<I+1>(a,b)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <size_t I, class TupleA, class TupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_less_impl(TupleA const& a, TupleB const& b) | |
{ | |
if constexpr (I == tuple_size<TupleB>::value) { | |
return cute::false_type{}; // Terminal: TupleB is exhausted | |
} else if constexpr (I == tuple_size<TupleA>::value) { | |
return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted | |
} else { | |
constexpr size_t A = tuple_size<TupleA>::value - 1 - I; | |
constexpr size_t B = tuple_size<TupleB>::value - 1 - I; | |
return colex_less(get<A>(a), get<B>(b)) || (get<A>(a) == get<B>(b) && colex_less_impl<I+1>(a,b)); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <size_t I, class TupleA, class TupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_less_impl(TupleA const& a, TupleB const& b) | |
{ | |
if constexpr (I == tuple_size<TupleA>::value) { | |
return cute::true_type{}; // Terminal: TupleA is exhausted | |
} else if constexpr (I == tuple_size<TupleB>::value) { | |
return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted | |
} else { | |
return elem_less(get<I>(a), get<I>(b)) && elem_less_impl<I+1>(a,b); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
} // end namespace detail | |
// Lexicographical comparison | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_less(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
return detail::lex_less_impl<0>(a, b); | |
} else { | |
return a < b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_leq(T const& t, U const& u) { | |
return !lex_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_gtr(T const& t, U const& u) { | |
return lex_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
lex_geq(T const& t, U const& u) { | |
return !lex_less(t, u); | |
} | |
// Colexicographical comparison | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_less(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
return detail::colex_less_impl<0>(a, b); | |
} else { | |
return a < b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_leq(T const& t, U const& u) { | |
return !colex_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_gtr(T const& t, U const& u) { | |
return colex_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
colex_geq(T const& t, U const& u) { | |
return !colex_less(t, u); | |
} | |
// Elementwise [all] comparison | |
template <class IntTupleA, class IntTupleB> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_less(IntTupleA const& a, IntTupleB const& b) | |
{ | |
if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) { | |
return detail::elem_less_impl<0>(a, b); | |
} else { | |
return a < b; | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_leq(T const& t, U const& u) { | |
return !elem_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_gtr(T const& t, U const& u) { | |
return elem_less(u, t); | |
} | |
template <class T, class U> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
elem_geq(T const& t, U const& u) { | |
return !elem_less(t, u); | |
} | |
namespace detail { | |
/** Increment a (dynamic) coord lexicographically within a shape | |
* @pre is_congruent<Coord,Shape>::value | |
* \code | |
* auto shape = make_shape(1,2,make_shape(2,3),3); | |
* | |
* int i = 0; | |
* for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) { | |
* std::cout << i++ << ": " << coord << std::endl; | |
* } | |
* assert(i == size(shape)); | |
* \endcode | |
*/ | |
template <int I = 0, class Coord, class Shape> | |
CUTE_HOST_DEVICE constexpr | |
void | |
increment(Coord& coord, Shape const& shape) | |
{ | |
if constexpr (is_integral<Coord>::value) { | |
++coord; | |
} else { | |
increment(get<I>(coord), get<I>(shape)); | |
if constexpr (I+1 < tuple_size<Coord>::value) { | |
if (back(get<I>(coord)) == back(get<I>(shape))) { | |
back(get<I>(coord)) = 0; | |
increment<I+1>(coord, shape); | |
} | |
} | |
} | |
} | |
} // end namespace detail | |
struct ForwardCoordIteratorSentinal | |
{}; | |
// A forward iterator for a starting coordinate in a shape's domain, and a shape. | |
// The starting coordinate may be zero but need not necessarily be. | |
template <class Coord, class Shape> | |
struct ForwardCoordIterator | |
{ | |
static_assert(is_congruent<Coord, Shape>::value); | |
CUTE_HOST_DEVICE constexpr | |
Coord const& operator*() const { return coord; } | |
CUTE_HOST_DEVICE constexpr | |
ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; } | |
// Sentinel for the end of the implied range | |
CUTE_HOST_DEVICE constexpr | |
bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); } | |
CUTE_HOST_DEVICE constexpr | |
bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); } | |
CUTE_HOST_DEVICE constexpr | |
bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); } | |
// NOTE: These are expensive, avoid use | |
CUTE_HOST_DEVICE constexpr | |
bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); } | |
CUTE_HOST_DEVICE constexpr | |
bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } | |
CUTE_HOST_DEVICE constexpr | |
bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } | |
Coord coord; | |
Shape const& shape; | |
}; | |
// A forward iterator for a coordinate that starts from a provided coordinate | |
template <class Shape, class Coord> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
make_coord_iterator(Coord const& coord, Shape const& shape) | |
{ | |
return ForwardCoordIterator<Coord,Shape>{coord,shape}; | |
} | |
// A forward iterator for a coordinate that starts from zero | |
template <class Shape> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
make_coord_iterator(Shape const& shape) | |
{ | |
auto coord = repeat_like(shape, int(0)); | |
return make_coord_iterator(coord, shape); | |
} | |
} // end namespace cute | |