Do0rMaMu's picture
Upload folder using huggingface_hub
e45d058 verified
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/container/tuple.hpp>
#include <cute/container/array.hpp>
#include <cute/algorithm/tuple_algorithms.hpp>
#include <cute/numeric/integral_constant.hpp>
/** IntTuple is an integer or a tuple of IntTuples.
* This file holds utilities for working with IntTuples,
* but does not hold a concrete concept or class of IntTuple.
*/
namespace cute
{
// Implementation of get<0>(Integral).
// Even though is_tuple<Integral> is false and tuple_size<Integral> doesn't compile,
// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input.
//
// The overload is constrained (via __CUTE_REQUIRES) to types T for which
// cute::is_integral<cute::remove_cvref_t<T>> holds.
//   I : index to fetch; anything other than 0 trips the static_assert below.
//   t : the integral value, taken as a forwarding reference.
// Returns t cast back to T&&, i.e. the very same object with its value category preserved.
template <size_t I, class T, __CUTE_REQUIRES(cute::is_integral<cute::remove_cvref_t<T>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(T&& t) noexcept
{
static_assert(I == 0, "Index out of range");
return static_cast<T&&>(t);
}
// Custom recursive get for anything that implements get<I>(.) (for a single integer I).
// get<I0, I1, Is...>(t) peels off the leading index: it applies get<I0> to t and then
// resolves the remaining indices <I1, Is...> on that result, so that
//   get<1,0>(t)  is  get<0>(get<1>(t)).
// The return type is decltype(auto), so whatever the innermost get yields is passed through as-is.
template <size_t I0, size_t I1, size_t... Is, class T>
CUTE_HOST_DEVICE constexpr
decltype(auto)
get(T&& t) noexcept
{
return get<I1, Is...>(get<I0>(static_cast<T&&>(t)));
}
//
// rank
//
// Rank of an IntTuple, returned as a static integer (Int<N>).
//   rank(t)        : number of top-level elements of a tuple; any non-tuple has rank 1.
//   rank<Is...>(t) : rank of the element selected by get<Is...>(t).
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
rank(IntTuple const& t)
{
  if constexpr (sizeof...(Is) != 0) {
    // Walk down to the requested element first, then ask for its rank
    return rank(get<Is...>(t));
  } else if constexpr (!is_tuple<IntTuple>::value) {
    return Int<1>{};
  } else {
    return Int<tuple_size<IntTuple>::value>{};
  }

  CUTE_GCC_UNREACHABLE;
}
// Type returned by rank(IntTuple); obtained through decltype, so nothing is evaluated.
template <class IntTuple>
using rank_t = decltype(rank(declval<IntTuple>()));
// Rank of IntTuple as a compile-time int (the ::value of rank_t).
template <class IntTuple>
static constexpr int rank_v = rank_t<IntTuple>::value;
//
// shape
//
// Shape of an IntTuple: a tuple is mapped element-by-element through shape(),
// and a non-tuple leaf is handed back unchanged.
template <class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
shape(IntTuple const& s)
{
  if constexpr (!is_tuple<IntTuple>::value) {
    return s;
  } else {
    return transform(s, [](auto const& mode) { return shape(mode); });
  }

  CUTE_GCC_UNREACHABLE;
}
// shape<I, Is...>(s): shape of the element of s addressed by the index path I, Is...
template <int I, int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
shape(IntTuple const& s)
{
  if constexpr (!is_tuple<IntTuple>::value) {
    // Not a tuple: compute its shape first and index into that result
    return get<I,Is...>(shape(s));
  } else {
    // Tuple: step into element I and resolve the remaining indices there
    return shape<Is...>(get<I>(s));
  }

  CUTE_GCC_UNREACHABLE;
}
//
// max
//
// Maximum over all arguments.
// A leading tuple argument is first reduced to the maximum of its own elements;
// the remaining arguments are then folded in from the right.
template <class T0, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
max(T0 const& t0, Ts const&... ts)
{
  if constexpr (is_tuple<T0>::value) {
    // Collapse the tuple into a single value, then continue with the rest
    auto const head_max = cute::apply(t0, [](auto const&... elems) { return cute::max(elems...); });
    return cute::max(head_max, ts...);
  } else if constexpr (sizeof...(Ts) != 0) {
    return cute::max(t0, cute::max(ts...));
  } else {
    return t0;
  }

  CUTE_GCC_UNREACHABLE;
}
//
// min
//
// Minimum over all arguments.
// A leading tuple argument is first reduced to the minimum of its own elements;
// the remaining arguments are then folded in from the right.
template <class T0, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
min(T0 const& t0, Ts const&... ts)
{
  if constexpr (is_tuple<T0>::value) {
    // Collapse the tuple into a single value, then continue with the rest
    auto const head_min = cute::apply(t0, [](auto const&... elems) { return cute::min(elems...); });
    return cute::min(head_min, ts...);
  } else if constexpr (sizeof...(Ts) != 0) {
    return cute::min(t0, cute::min(ts...));
  } else {
    return t0;
  }

  CUTE_GCC_UNREACHABLE;
}
//
// gcd
//
// Greatest common divisor over all arguments.
// A leading tuple argument is first reduced to the gcd of its own elements;
// the remaining arguments are then folded in from the right.
template <class T0, class... Ts>
CUTE_HOST_DEVICE constexpr
auto
gcd(T0 const& t0, Ts const&... ts)
{
  if constexpr (is_tuple<T0>::value) {
    // Collapse the tuple into a single value, then continue with the rest
    auto const head_gcd = cute::apply(t0, [](auto const&... elems) { return cute::gcd(elems...); });
    return cute::gcd(head_gcd, ts...);
  } else if constexpr (sizeof...(Ts) != 0) {
    return cute::gcd(t0, cute::gcd(ts...));
  } else {
    return t0;
  }

  CUTE_GCC_UNREACHABLE;
}
//
// depth
//
// Nesting depth of an IntTuple, returned as a static integer.
//   depth(t)        : Int<0> for a non-tuple; for a tuple, one more than the deepest element.
//   depth<Is...>(t) : depth of the element selected by get<Is...>(t).
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
depth(IntTuple const& t)
{
  if constexpr (sizeof...(Is) != 0) {
    return depth(get<Is...>(t));
  } else if constexpr (!is_tuple<IntTuple>::value) {
    return Int<0>{};
  } else {
    auto const deepest_child = cute::apply(t, [](auto const&... elems) { return cute::max(depth(elems)...); });
    return Int<1>{} + deepest_child;
  }

  CUTE_GCC_UNREACHABLE;
}
// Type returned by depth(Tuple); obtained through decltype, so nothing is evaluated.
template <class Tuple>
using depth_t = decltype(depth(declval<Tuple>()));
// Depth of Tuple as a compile-time int (the ::value of depth_t).
template <class Tuple>
static constexpr int depth_v = depth_t<Tuple>::value;
//
// product
//
// Function object computing the product of every element of an IntTuple.
//   - empty tuple        -> Int<1>
//   - non-empty tuple    -> each element is reduced with Product, the results are multiplied
//   - integral           -> the value itself
// Any other argument type matches neither branch, so no value is returned for it.
struct Product
{
  template <class IntTuple>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(IntTuple const& a) const
  {
    if constexpr (is_tuple<IntTuple>::value) {
      if constexpr (tuple_size<IntTuple>::value != 0) {
        return cute::transform_apply(a, Product{}, multiplies_unary_lfold{});
      } else {
        return Int<1>{};   // empty product
      }
    } else if constexpr (cute::is_integral<IntTuple>::value) {
      return a;
    }

    CUTE_GCC_UNREACHABLE;
  }
};
// Callable product function object: product(t) invokes Product::operator() on t.
// Being an object rather than a function, it can also be passed directly as a callable
// (as product_each does with transform).
CUTE_INLINE_CONSTANT Product product;
// Return a rank(t) tuple @a result such that get<i>(@a result) = product(get<i>(@a t))
// The input is passed through wrap() before the per-element transform
// (presumably so that a non-tuple input is treated as a one-element tuple -- see wrap()).
template <class Tuple>
CUTE_HOST_DEVICE constexpr
auto
product_each(Tuple const& t)
{
return transform(wrap(t), product);
}
// Take the product of Tuple at the leaves of TupleG.
// The guide only dictates where the reductions happen; the guide's leaf values are not used.
template <class Tuple, class TupleG>
CUTE_HOST_DEVICE constexpr
auto
product_like(Tuple const& tuple, TupleG const& guide)
{
  return transform_leaf(guide, tuple,
                        [](auto const& /*guide_leaf*/, auto const& sub) { return product(sub); });
}
// Return the product of elements in a mode.
//   size(a)        : product of all elements of a.
//   size<Is...>(a) : size of the element selected by get<Is...>(a).
template <int... Is, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
size(IntTuple const& a)
{
  if constexpr (sizeof...(Is) != 0) {
    return size(get<Is...>(a));
  } else {
    return product(a);
  }

  CUTE_GCC_UNREACHABLE;
}
// Size of IntTuple as a compile-time int.
// Reads ::value from the type returned by size(), so that type must expose a static value.
template <class IntTuple>
static constexpr int size_v = decltype(size(declval<IntTuple>()))::value;
//
// sum
//
// Sum of all leaves of an IntTuple.
// A non-tuple is its own sum; a tuple sums the (recursive) sums of its elements,
// starting the fold from Int<0>.
template <class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
sum(IntTuple const& a)
{
  if constexpr (!is_tuple<IntTuple>::value) {
    return a;
  } else {
    return cute::apply(a, [](auto const&... elems) { return (Int<0>{} + ... + sum(elems)); });
  }

  CUTE_GCC_UNREACHABLE;
}
//
// inner_product
//
// Inner product of two IntTuples.
// When both arguments are tuples (of equal rank, enforced statically) the inner products of
// corresponding elements are summed, starting from Int<0>; otherwise the result is a * b.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
inner_product(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (!(is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value)) {
    return a * b;
  } else {
    static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
    return transform_apply(a, b,
                           [](auto const& ai, auto const& bi) { return inner_product(ai, bi); },
                           [](auto const&... terms)           { return (Int<0>{} + ... + terms); });
  }

  CUTE_GCC_UNREACHABLE;
}
//
// ceil_div
//
// Ceiling division for IntTuples.
//   tuple / tuple : element-wise; b may have fewer elements than a, missing ones count as 1.
//   tuple / int   : the divisor is folded across the elements of a -- each element is divided by
//                   the divisor that is still left, and the divisor is in turn divided by that element.
//   int   / tuple : a is divided by product(b).
//   int   / int   : (a + b - 1) / b.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
ceil_div(IntTupleA const& a, IntTupleB const& b)
{
  constexpr bool a_is_tuple = is_tuple<IntTupleA>::value;
  constexpr bool b_is_tuple = is_tuple<IntTupleB>::value;

  if constexpr (a_is_tuple && b_is_tuple) {            // tuple tuple
    static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
    constexpr int RankA = tuple_size<IntTupleA>::value;
    // Pad b with ones up to the rank of a
    auto const divisors = append<RankA>(b, Int<1>{});
    return transform(a, divisors, [](auto const& num, auto const& den) { return ceil_div(num, den); });
  } else if constexpr (a_is_tuple) {                   // tuple int
    // Fold state: (quotients collected so far, divisor still to be distributed)
    auto const folded = fold(a, cute::make_tuple(cute::make_tuple(), b),
      [] (auto const& state, auto const& ai) {
        auto const& quotients = get<0>(state);
        auto const& remaining = get<1>(state);
        return cute::make_tuple(append(quotients, ceil_div(ai, remaining)), ceil_div(remaining, ai));
      });
    return get<0>(folded);
  } else if constexpr (b_is_tuple) {                   // int tuple
    return ceil_div(a, product(b));
  } else {                                             // int int
    return (a + b - Int<1>{}) / b;
  }

  CUTE_GCC_UNREACHABLE;
}
//
// round_up
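// Round @a a up to the nearest multiple of @a b.
// Computed as ((a + b - 1) / b) * b, which is exact for a >= 0 and b > 0.
// NOTE(review): with truncating integer division, negative inputs are not rounded
// consistently (e.g. a = -8, b = 4 yields -4), so do not rely on the result for a < 0.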
//
// round_up(a, b)
//   tuple, tuple : element-wise; b may have fewer elements than a, missing ones count as 1.
//   otherwise    : ((a + b - 1) / b) * b
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
round_up(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (!(is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value)) {
    auto const multiples = (a + b - Int<1>{}) / b;   // number of b-sized units needed to cover a
    return multiples * b;
  } else {
    static_assert(tuple_size<IntTupleA>::value >= tuple_size<IntTupleB>::value, "Mismatched ranks");
    constexpr int RankA = tuple_size<IntTupleA>::value;
    // Pad b with ones up to the rank of a
    auto const granules = append<RankA>(b, Int<1>{});
    return transform(a, granules, [](auto const& ai, auto const& bi) { return round_up(ai, bi); });
  }

  CUTE_GCC_UNREACHABLE;
}
/** Division for Shapes
 * Case Tuple Tuple:
 *   Perform shape_div element-wise (ranks must match)
 * Case Tuple Int:
 *   Fold the division of b across each element of a
 *   Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3)
 * Case Int Tuple:
 *   Return shape_div(a, product(b))
 * Case Int Int:
 *   Enforce the divisibility condition a % b == 0 || b % a == 0 when possible
 *   (checked with a static_assert when both operands are static; not checked at run time)
 *   Return a / b with rounding away from 0 (that is, 1 or -1 when a < b)
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
shape_div(IntTupleA const& a, IntTupleB const& b)
{
  constexpr bool a_is_tuple = is_tuple<IntTupleA>::value;
  constexpr bool b_is_tuple = is_tuple<IntTupleB>::value;

  if constexpr (a_is_tuple && b_is_tuple) {            // tuple tuple
    static_assert(tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value, "Mismatched ranks");
    return transform(a, b, [](auto const& ai, auto const& bi) { return shape_div(ai, bi); });
  } else if constexpr (a_is_tuple) {                   // tuple int
    // Fold state: (quotients collected so far, divisor still to be distributed)
    auto const folded = fold(a, cute::make_tuple(cute::make_tuple(), b),
      [] (auto const& state, auto const& ai) {
        auto const& quotients = get<0>(state);
        auto const& remaining = get<1>(state);
        return cute::make_tuple(append(quotients, shape_div(ai, remaining)), shape_div(remaining, ai));
      });
    return get<0>(folded);
  } else if constexpr (b_is_tuple) {                   // int tuple
    return shape_div(a, product(b));
  } else if constexpr (is_static<IntTupleA>::value && is_static<IntTupleB>::value) {  // static int, static int
    static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure");
    // Evaluate on the raw values at compile time and wrap the result back into a static integer
    return C<shape_div(IntTupleA::value, IntTupleB::value)>{};
  } else {                                             // int int
    //assert(a % b == 0 || b % a == 0);          // Waive dynamic assertion
    auto const quotient = a / b;
    return quotient != 0 ? quotient : signum(a) * signum(b);  // Division with rounding away from zero
  }

  CUTE_GCC_UNREACHABLE;
}
/** Minimum for Shapes
 * Only defined for non-tuple arguments; passing a tuple on either side is a compile-time error.
 * If either argument is the static constant 1 the result is the static Int<1>,
 * otherwise it is cute::min(a, b).
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
shape_min(IntTupleA const& a, IntTupleB const& b)
{
  constexpr bool any_tuple      = is_tuple<IntTupleA>::value || is_tuple<IntTupleB>::value;
  constexpr bool any_static_one = is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value;

  if constexpr (any_tuple) {
    static_assert(dependent_false<IntTupleA>, "Not implemented.");
  } else if constexpr (!any_static_one) {
    return cute::min(a, b);
  } else {
    return Int<1>{};        // _1 is less than all other shapes, preserve static
  }

  CUTE_GCC_UNREACHABLE;
}
/** Return a tuple the same profile as A scaled by corresponding elements in B
 * At each leaf of A the scale factor is product() of the matching part of B.
 */
template <class A, class B>
CUTE_HOST_DEVICE constexpr
auto
elem_scale(A const& a, B const& b)
{
  if constexpr (!is_tuple<A>::value) {
    return a * product(b);
  } else {
    return transform(a, b, [](auto const& ai, auto const& bi) { return elem_scale(ai, bi); });
  }

  CUTE_GCC_UNREACHABLE;
}
/** Test if two IntTuple have the same profile (hierarchical rank division)
 * Both shapes are rebuilt with every leaf replaced by _0; the tuples are congruent
 * exactly when the two resulting types are identical. The answer is a static bool_constant.
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
congruent(IntTupleA const& a, IntTupleB const& b)
{
  using ProfileA = decltype(repeat_like(shape(a), _0{}));
  using ProfileB = decltype(repeat_like(shape(b), _0{}));
  return bool_constant<is_same<ProfileA, ProfileB>::value>{};
}
// Type-level form of congruent(): the type returned by congruent(A, B), usable as is_congruent<A,B>::value.
template <class A, class B>
using is_congruent = decltype(congruent(declval<A>(), declval<B>()));
/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division)
 * weakly_congruent is a partial order on A and B: A <= B
 *
 * Two tuples must have equal rank and be weakly congruent element by element.
 * An integral A matches anything; a non-integral A never matches an integral B.
 * Arguments that are neither tuples nor integrals are compared through their shape().
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
weakly_congruent(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
    if constexpr (tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value) {
      return transform_apply(a, b,
                             [](auto const& ai, auto const& bi) { return weakly_congruent(ai, bi); },
                             [](auto const&... ok)              { return (true_type{} && ... && ok); });
    } else {
      return false_type{};
    }
  } else if constexpr (is_integral<IntTupleA>::value) {
    return true_type{};
  } else if constexpr (is_integral<IntTupleB>::value) {
    return false_type{};
  } else {
    return weakly_congruent(shape(a), shape(b));
  }

  CUTE_GCC_UNREACHABLE;
}
// Type-level form of weakly_congruent(): the type returned by weakly_congruent(A, B).
template <class A, class B>
using is_weakly_congruent = decltype(weakly_congruent(declval<A>(), declval<B>()));
/** Test if Shape A is compatible with Shape B:
 *    the size of A and B are the same, and
 *    any coordinate into A can also be used as a coordinate into B
 * compatible is a partial order on A and B: A <= B
 *
 * Two tuples must have equal rank and be compatible element by element.
 * An integral A is compatible with B when a == size(b); a non-integral A is never
 * compatible with an integral B. Anything else is compared through shape().
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
compatible(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
    if constexpr (tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value) {
      return transform_apply(a, b,
                             [](auto const& ai, auto const& bi) { return compatible(ai, bi); },
                             [](auto const&... ok)              { return (true_type{} && ... && ok); });
    } else {
      return false_type{};
    }
  } else if constexpr (is_integral<IntTupleA>::value) {
    return a == size(b);
  } else if constexpr (is_integral<IntTupleB>::value) {
    return false_type{};
  } else {
    return compatible(shape(a), shape(b));
  }

  CUTE_GCC_UNREACHABLE;
}
// Type-level form of compatible(): the type returned by compatible(A, B).
// NOTE: compatible() can return a plain bool when sizes are dynamic, so ::value is only
// available when the comparison resolves to a static type.
template <class A, class B>
using is_compatible = decltype(compatible(declval<A>(), declval<B>()));
/** Test if Shape A is weakly compatible with Shape B:
 *    there exists a Shape C congruent to A such that compatible(elem_scale(A,C), B)
 * weakly_compatible is a partial order on A and B: A <= B
 *
 * Two tuples must have equal rank and be weakly compatible element by element.
 * An integral A is weakly compatible with B when it divides size(b); a non-integral A is never
 * weakly compatible with an integral B. Anything else is compared through shape().
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
weakly_compatible(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value) {
    if constexpr (tuple_size<IntTupleA>::value == tuple_size<IntTupleB>::value) {
      return transform_apply(a, b,
                             [](auto const& ai, auto const& bi) { return weakly_compatible(ai, bi); },
                             [](auto const&... ok)              { return (true_type{} && ... && ok); });
    } else {
      return false_type{};
    }
  } else if constexpr (is_integral<IntTupleA>::value) {
    return size(b) % a == Int<0>{};
  } else if constexpr (is_integral<IntTupleB>::value) {
    return false_type{};
  } else {
    return weakly_compatible(shape(a), shape(b));
  }

  CUTE_GCC_UNREACHABLE;
}
// Type-level form of weakly_compatible(): the type returned by weakly_compatible(A, B).
template <class A, class B>
using is_weakly_compatible = decltype(weakly_compatible(declval<A>(), declval<B>()));
/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1>
 * A is walked recursively alongside B; only a *static* zero in A triggers the replacement,
 * every other leaf of B is passed through unchanged.
 */
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (is_tuple<IntTupleA>::value) {
    return transform(a, b, [](auto const& ai, auto const& bi) { return filter_zeros(ai, bi); });
  } else if constexpr (!is_constant<0, IntTupleA>::value) {
    return b;
  } else {
    return Int<1>{};
  }

  CUTE_GCC_UNREACHABLE;
}
// Single-argument form: uses t both as the pattern and as the values,
// i.e. every static Int<0> leaf of t is replaced by Int<1>.
template <class Tuple>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tuple const& t)
{
return filter_zeros(t, t);
}
//
// Converters and constructors with arrays and params
//
/** Make an IntTuple of rank N from an Indexable array.
 * Access elements up to a dynamic index n, then use init (requires compatible types)
 * Consider cute::take<B,E> if all indexing is known to be valid
 * For N == 1 the result is the single value itself rather than a one-element tuple.
 * \code
 *   std::vector<int> a = {6,3,4};
 *   auto tup = make_int_tuple<5>(a, a.size(), 0);           // (6,3,4,0,0)
 * \endcode
 */
template <int N, class Indexable, class T>
CUTE_HOST_DEVICE constexpr
auto
make_int_tuple(Indexable const& t, int n, T const& init)
{
  static_assert(N > 0);
  if constexpr (N != 1) {
    // Positions below n read from t, the rest fall back to init
    return transform(make_seq<N>{}, [&](auto idx) { return idx < n ? t[idx] : init; });
  } else {
    return 0 < n ? t[0] : init;
  }

  CUTE_GCC_UNREACHABLE;
}
/** Fill the dynamic values of a Tuple with values from another Tuple
 * Static elements of result are skipped, nested tuples are filled recursively, and every
 * dynamic element takes the next unused value. The return value is what the fold yields:
 * the values that were not consumed.
 * \code
 *   auto params = make_tuple(6,3,4);
 *   cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>> result;
 *   fill_int_tuple_from(result, params);                      // (_1,(6,3,_3),4,_2)
 * \endcode
 */
template <class Tuple, class TupleV>
CUTE_HOST_DEVICE constexpr
auto
fill_int_tuple_from(Tuple& result, TupleV const& vals)
{
  return fold(result, vals, [](auto const& remaining, auto&& slot) {
    using Slot = remove_cvref_t<decltype(slot)>;

    if constexpr (is_static<Slot>::value) {               // Static slot: nothing to fill
      return remaining;
    } else if constexpr (is_tuple<Slot>::value) {         // Nested tuple: fill it from the same pool
      return fill_int_tuple_from(slot, remaining);
    } else {                                              // Dynamic slot: take the next value
      static_assert(tuple_size<remove_cvref_t<decltype(remaining)>>::value > 0, "Not enough values to fill with!");
      slot = get<0>(remaining);
      return remove<0>(remaining);
    }

    CUTE_GCC_UNREACHABLE;
  });
}
/** Make a "Tuple" by filling in the dynamic values in order from the arguments
* The tuple is first value-constructed (Tuple{}), then its dynamic elements are overwritten
* via fill_int_tuple_from; any values left over after filling are discarded.
* \code
* using result_t = cute::tuple<Int<1>, cute::tuple<int, int, Int<3>>, int, Int<2>>;
* auto result = make_int_tuple_from<result_t>(6,3,4); // (_1,(6,3,_3),4,_2)
* \endcode
*/
template <class Tuple, class... Ts>
CUTE_HOST_DEVICE constexpr
Tuple
make_int_tuple_from(Ts const&... ts)
{
Tuple result = Tuple{};
fill_int_tuple_from(result, cute::make_tuple(ts...));
return result;
}
/** Convert a tuple to a flat homogeneous array of type T
 * The tuple is flattened first, then each leaf is assigned (converted) to T in order.
 * \code
 *   auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{});
 *   cute::array<uint64_t,6> result = to_array<uint64_t>(tup);   // [1,6,3,3,4,2]
 * \endcode
 */
template <class T = int64_t, class IntTuple>
CUTE_HOST_DEVICE constexpr
auto
to_array(IntTuple const& t)
{
  auto flat_t = flatten_to_tuple(t);
  constexpr int N = tuple_size<decltype(flat_t)>::value;
  // Value-initialize the array: a local left uninitialized is not permitted in a C++17
  // constant evaluation, and this also guarantees no indeterminate element can be returned.
  cute::array<T,N> result{};
  for_each(make_seq<N>{}, [&] (auto i) { result[i] = get<i>(flat_t); });
  return result;
}
//
// Comparison operators
//
//
// There are many ways to compare tuple of elements and because CuTe is built
// on parameterizing layouts of coordinates, some comparisons are appropriate
// only in certain cases.
// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout
// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout
// -- element-wise comparison [any,all] :
// This can be very confusing. To avoid errors in selecting the appropriate
// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple.
//
// When actually desiring to order coordinates, the user should map them to
// their indices within the Layout they came from:
// e.g. layoutX(coordA) < layoutX(coordB)
// That said, we implement the three most common ways to compare tuples below.
// These are implemented with slightly more explicit names than op<.
//
// Forward declarations of the three comparison entry points.
// The detail::*_impl helpers that follow call these on individual elements before the
// definitions appear further down in this file.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less(IntTupleA const& a, IntTupleB const& b);
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less(IntTupleA const& a, IntTupleB const& b);
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less(IntTupleA const& a, IntTupleB const& b);
namespace detail {
// Lexicographical less-than on tuples, scanning elements I, I+1, ... from the front.
// Running out of b means "not less"; running out of a while b still has elements means "less".
template <size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less_impl(TupleA const& a, TupleB const& b)
{
  if constexpr (I != tuple_size<TupleB>::value) {
    if constexpr (I != tuple_size<TupleA>::value) {
      auto const& ai = get<I>(a);
      auto const& bi = get<I>(b);
      // Decided at this position, or tied here and decided by the remaining elements
      return lex_less(ai, bi) || (ai == bi && lex_less_impl<I+1>(a,b));
    } else {
      return cute::true_type{};     // Terminal: TupleA is exhausted, TupleB is not exhausted
    }
  } else {
    return cute::false_type{};      // Terminal: TupleB is exhausted
  }

  CUTE_GCC_UNREACHABLE;
}
// Colexicographical less-than on tuples: same scheme as the lexicographical helper, but step I
// compares the I-th element counted from the *back* of each tuple.
template <size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less_impl(TupleA const& a, TupleB const& b)
{
  if constexpr (I != tuple_size<TupleB>::value) {
    if constexpr (I != tuple_size<TupleA>::value) {
      constexpr size_t IdxA = tuple_size<TupleA>::value - 1 - I;
      constexpr size_t IdxB = tuple_size<TupleB>::value - 1 - I;
      auto const& ai = get<IdxA>(a);
      auto const& bi = get<IdxB>(b);
      // Decided at this position, or tied here and decided by the elements further towards the front
      return colex_less(ai, bi) || (ai == bi && colex_less_impl<I+1>(a,b));
    } else {
      return cute::true_type{};     // Terminal: TupleA is exhausted, TupleB is not exhausted
    }
  } else {
    return cute::false_type{};      // Terminal: TupleB is exhausted
  }

  CUTE_GCC_UNREACHABLE;
}
// Element-wise [all] less-than on tuples: true only if every element of a is less than the
// element of b at the same position. Running out of a ends the scan successfully;
// running out of b while a still has elements makes the result false.
template <size_t I, class TupleA, class TupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less_impl(TupleA const& a, TupleB const& b)
{
  if constexpr (I != tuple_size<TupleA>::value) {
    if constexpr (I != tuple_size<TupleB>::value) {
      auto const& ai = get<I>(a);
      auto const& bi = get<I>(b);
      return elem_less(ai, bi) && elem_less_impl<I+1>(a,b);
    } else {
      return cute::false_type{};    // Terminal: TupleA is not exhausted, TupleB is exhausted
    }
  } else {
    return cute::true_type{};       // Terminal: TupleA is exhausted
  }

  CUTE_GCC_UNREACHABLE;
}
} // end namespace detail
// Lexicographical comparison
// Two tuples are compared element by element from the front; any other pair of
// arguments is compared directly with operator<.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
lex_less(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (!(is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value)) {
    return a < b;
  } else {
    return detail::lex_less_impl<0>(a, b);
  }

  CUTE_GCC_UNREACHABLE;
}
// Lexicographical t <= u, expressed as "u is not less than t".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_leq(T const& t, U const& u) {
return !lex_less(u, t);
}
// Lexicographical t > u, expressed as "u is less than t".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_gtr(T const& t, U const& u) {
return lex_less(u, t);
}
// Lexicographical t >= u, expressed as "t is not less than u".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
lex_geq(T const& t, U const& u) {
return !lex_less(t, u);
}
// Colexicographical comparison
// Two tuples are compared element by element from the back; any other pair of
// arguments is compared directly with operator<.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
colex_less(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (!(is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value)) {
    return a < b;
  } else {
    return detail::colex_less_impl<0>(a, b);
  }

  CUTE_GCC_UNREACHABLE;
}
// Colexicographical t <= u, expressed as "u is not less than t".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_leq(T const& t, U const& u) {
return !colex_less(u, t);
}
// Colexicographical t > u, expressed as "u is less than t".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_gtr(T const& t, U const& u) {
return colex_less(u, t);
}
// Colexicographical t >= u, expressed as "t is not less than u".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
colex_geq(T const& t, U const& u) {
return !colex_less(t, u);
}
// Elementwise [all] comparison
// For two tuples the result is true only when every element of a is less than its
// counterpart in b; any other pair of arguments is compared directly with operator<.
template <class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr
auto
elem_less(IntTupleA const& a, IntTupleB const& b)
{
  if constexpr (!(is_tuple<IntTupleA>::value && is_tuple<IntTupleB>::value)) {
    return a < b;
  } else {
    return detail::elem_less_impl<0>(a, b);
  }

  CUTE_GCC_UNREACHABLE;
}
// Defined as the negation of elem_less(u, t).
// NOTE: because elem_less is an "all elements" test, this negation means "not every element
// of u is less than t"; it is NOT the same as "every element of t is <= the matching element of u".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_leq(T const& t, U const& u) {
return !elem_less(u, t);
}
// Element-wise t > u: every element of u is less than the matching element of t.
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_gtr(T const& t, U const& u) {
return elem_less(u, t);
}
// Defined as the negation of elem_less(t, u).
// NOTE: as with elem_leq, this is "not every element of t is less than u", which is weaker
// than "every element of t is >= the matching element of u".
template <class T, class U>
CUTE_HOST_DEVICE constexpr
auto
elem_geq(T const& t, U const& u) {
return !elem_less(t, u);
}
namespace detail {
/** Increment a (dynamic) coord colexicographically within a shape
 * Mode 0 advances fastest. After advancing mode I, if the last leaf of that mode has reached
 * the last leaf of the matching shape mode, that leaf is reset to 0 and the carry moves on to
 * mode I+1. The final mode is never wrapped, so back(coord) == back(shape) marks the end.
 * @pre is_congruent<Coord,Shape>::value
 * \code
 *   auto shape = make_shape(1,2,make_shape(2,3),3);
 *
 *   int i = 0;
 *   for (auto coord = repeat_like(shape, 0); back(coord) != back(shape); increment(coord, shape)) {
 *     std::cout << i++ << ": " << coord << std::endl;
 *   }
 *   assert(i == size(shape));
 * \endcode
 */
template <int I = 0, class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
void
increment(Coord& coord, Shape const& shape)
{
  if constexpr (!is_integral<Coord>::value) {
    auto&       coord_i = get<I>(coord);
    auto const& shape_i = get<I>(shape);

    // Advance the current mode (recursing into it if it is itself a tuple)
    increment(coord_i, shape_i);

    if constexpr (I+1 < tuple_size<Coord>::value) {
      // Overflow of this mode: wrap it and carry into the next one
      if (back(coord_i) == back(shape_i)) {
        back(coord_i) = 0;
        increment<I+1>(coord, shape);
      }
    }
  } else {
    ++coord;
  }
}
} // end namespace detail
// Empty tag type standing for "one past the last coordinate".
// ForwardCoordIterator provides <, == and != against this type to detect the end of its range.
struct ForwardCoordIteratorSentinal
{};
// A forward iterator for a starting coordinate in a shape's domain, and a shape.
// The starting coordinate may be zero but need not necessarily be.
//
// The coordinate is stored by value and advanced with detail::increment; the shape is stored
// by reference, so the shape object must outlive the iterator.
template <class Coord, class Shape>
struct ForwardCoordIterator
{
// The coordinate must have the same hierarchical profile as the shape
static_assert(is_congruent<Coord, Shape>::value);
// Current coordinate
CUTE_HOST_DEVICE constexpr
Coord const& operator*() const { return coord; }
// Pre-increment: step to the next coordinate within the shape
CUTE_HOST_DEVICE constexpr
ForwardCoordIterator& operator++() { detail::increment(coord, shape); return *this; }
// Sentinel for the end of the implied range
// The end is reached when the last leaf of the coordinate equals the last leaf of the shape.
CUTE_HOST_DEVICE constexpr
bool operator< (ForwardCoordIteratorSentinal const&) const { return back(coord) < back(shape); }
CUTE_HOST_DEVICE constexpr
bool operator==(ForwardCoordIteratorSentinal const&) const { return back(coord) == back(shape); }
CUTE_HOST_DEVICE constexpr
bool operator!=(ForwardCoordIteratorSentinal const&) const { return back(coord) != back(shape); }
// NOTE: These are expensive, avoid use
// Iterator-to-iterator comparisons look at the coordinates only (ordering is colexicographical);
// the shapes of the two iterators are not compared.
CUTE_HOST_DEVICE constexpr
bool operator< (ForwardCoordIterator const& other) const { return colex_less(coord, other.coord); }
CUTE_HOST_DEVICE constexpr
bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; }
CUTE_HOST_DEVICE constexpr
bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; }
Coord coord;          // current position (owned copy)
Shape const& shape;   // extents being iterated over (not owned)
};
// A forward iterator for a coordinate that starts from a provided coordinate
// The coordinate is copied into the iterator; the shape is only referenced, so it must stay
// alive for as long as the returned iterator is used.
// Note the template parameters are ordered <Shape, Coord> while the function parameters are
// (coord, shape); both are normally deduced from the call.
template <class Shape, class Coord>
CUTE_HOST_DEVICE constexpr
auto
make_coord_iterator(Coord const& coord, Shape const& shape)
{
return ForwardCoordIterator<Coord,Shape>{coord,shape};
}
// A forward iterator for a coordinate that starts from zero
// The starting coordinate has the same profile as the shape with every leaf set to int 0.
// The shape is only referenced by the returned iterator, so it must outlive it.
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_coord_iterator(Shape const& shape)
{
  return make_coord_iterator(repeat_like(shape, int(0)), shape);
}
} // end namespace cute