// (stray extraction artifacts removed)
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
namespace cute | |
{ | |
namespace detail { | |
template <class X, class = void> | |
struct supports_output_scaling { static constexpr bool value = false; }; | |
template <class X> | |
struct supports_output_scaling<X, void_t<decltype(declval<X>().accumulate_)>> { static constexpr bool value = true; }; | |
} // end namespace detail | |
/** | |
* concept MMA_Traits | |
* { | |
* using ValTypeD = // Logical D-value type
* using ValTypeA = // Logical A-value type
* using ValTypeB = // Logical B-value type
* using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD)
* | |
* using FrgTypeA = // A-type consumed by MMA (if omitted, same as ValTypeA)
* using FrgTypeB = // B-type consumed by MMA (if omitted, same as ValTypeB)
* using FrgTypeC = // C-type consumed by MMA (if omitted, same as ValTypeC)
* | |
* using Shape_MNK = // Logical MxNxK shape of the MMA | |
* | |
* using ThrID = // Logical thread id (tid) -> tidx | |
* | |
* using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord | |
* using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord | |
* using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord | |
* }; | |
*/ | |
// Primary MMA_Traits template: intentionally unusable. Every supported MMA
// operation must provide its own specialization; instantiating this primary
// template fires the (dependent, so lazily-evaluated) static_assert below.
template <class Op, class... OpArgs>
struct MMA_Traits
{
  static_assert(sizeof(Op) == 0, "MMA_Traits not implemented for this MMA_Operation.");
};
// Traits for the scalar fallback operation UniversalFMA: a single-thread,
// single-value 1x1x1 MMA (D = A*B + C). No FrgTypes are declared, so the
// fragment types default to the Val types.
template <class D, class A, class B, class C>
struct MMA_Traits<UniversalFMA<D,A,B,C>>
{
using ValTypeD = D;
using ValTypeA = A;
using ValTypeB = B;
using ValTypeC = C;
// Logical shape of the MMA: M x N x K = 1 x 1 x 1
using Shape_MNK = Shape<_1,_1,_1>;
// Logical thread id (tid) -> tidx; a single thread performs the FMA
using ThrID = Layout<_1>;
// (Logical thread id (tid), Logical value id (vid)) -> coord
// (tid,vid) -> (m,k): one A value per thread
using ALayout = Layout<Shape<_1,_1>>;
// (tid,vid) -> (n,k): one B value per thread
using BLayout = Layout<Shape<_1,_1>>;
// (tid,vid) -> (m,n): one C/D value per thread
using CLayout = Layout<Shape<_1,_1>>;
};
//
// Generic mma_unpack for any MMA_Traits
//
// Bridges CuTe Tensor operands to the raw register-array interface of
// MMA_Op::fma: each operand tensor is recast to the op's register element
// type, then detail::explode expands the registers into the individual
// scalar arguments of fma().
template <class MMA_Op, class... MMA_Args,
class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr
void
mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
// The op consumes register arrays, so all operands must live in rmem.
static_assert(is_rmem<TD>::value, "Expected registers in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected registers in MMA_Atom::call");
static_assert(is_rmem<TC>::value, "Expected registers in MMA_Atom::call");
// Register value types from the MMA_Operation register arrays
// (element type of the array types MMA_Op::{D,A,B,C}Registers).
using RegTypeD = typename remove_extent<typename MMA_Op::DRegisters>::type;
using RegTypeA = typename remove_extent<typename MMA_Op::ARegisters>::type;
using RegTypeB = typename remove_extent<typename MMA_Op::BRegisters>::type;
using RegTypeC = typename remove_extent<typename MMA_Op::CRegisters>::type;
using MMATraits = MMA_Traits<MMA_Op, MMA_Args...>;
// Register counts per operand (the array extents).
// RegNumD is unused when the op has no separate D operand (RegTypeD == void).
[[maybe_unused]] constexpr int RegNumD = extent<typename MMA_Op::DRegisters>::value;
constexpr int RegNumA = extent<typename MMA_Op::ARegisters>::value;
constexpr int RegNumB = extent<typename MMA_Op::BRegisters>::value;
constexpr int RegNumC = extent<typename MMA_Op::CRegisters>::value;
Tensor rA = recast<RegTypeA>(A);
Tensor rB = recast<RegTypeB>(B);
CUTE_STATIC_ASSERT_V(size(rA) == Int<RegNumA>{});
CUTE_STATIC_ASSERT_V(size(rB) == Int<RegNumB>{});
// RegTypeD == void marks an op that accumulates in place: fma() takes no
// separate D argument, so C and D must alias the same registers.
if constexpr (is_same<RegTypeD, void>::value)
{
static_assert(is_same<typename TD::value_type, typename TC::value_type>::value, "GMMA C and D value_type must match.");
static_assert(is_same<DLayout, CLayout>::value, "GMMA C and D layouts must match.");
// assert((void*)&C == (void*)&D);
Tensor rC = recast<RegTypeC>(D); // NOTE: D and C are same, so use mutable D
//CUTE_STATIC_ASSERT_V(size(rC) == Int<RegNumC>{});
// If the traits expose an `accumulate_` member, pass its address as a
// trailing fma() argument so the op can zero/scale the accumulator.
if constexpr (detail::supports_output_scaling<MMATraits>::value) {
detail::explode(MMA_Op::fma,
rA, make_int_sequence<RegNumA>{},
rB, make_int_sequence<RegNumB>{},
rC, make_int_sequence<RegNumC>{},
&(traits.accumulate_), seq<0>{});
}
else {
detail::explode(MMA_Op::fma,
rA, make_int_sequence<RegNumA>{},
rB, make_int_sequence<RegNumB>{},
rC, make_int_sequence<RegNumC>{});
}
}
else {
// Separate D operand: recast destination D and source accumulator C.
Tensor rD = recast<RegTypeD>(D);
Tensor rC = recast<RegTypeC>(C);
CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumD>{});
CUTE_STATIC_ASSERT_V(size(rC) == Int<RegNumC>{});
if constexpr (detail::supports_output_scaling<MMATraits>::value) {
detail::explode(MMA_Op::fma,
rD, make_int_sequence<RegNumD>{},
rA, make_int_sequence<RegNumA>{},
rB, make_int_sequence<RegNumB>{},
rC, make_int_sequence<RegNumC>{},
&(traits.accumulate_), seq<0>{});
}
else {
detail::explode(MMA_Op::fma,
rD, make_int_sequence<RegNumD>{},
rA, make_int_sequence<RegNumA>{},
rB, make_int_sequence<RegNumB>{},
rC, make_int_sequence<RegNumC>{});
}
}
}
// | |
// Accept mutable temporaries | |
// | |
template <class MMA_Op, class... MMA_Args, | |
class TD, class DLayout, | |
class TA, class ALayout, | |
class TB, class BLayout, | |
class TC, class CLayout> | |
CUTE_HOST_DEVICE constexpr | |
void | |
mma_unpack(MMA_Traits<MMA_Op, MMA_Args...> const& traits, | |
Tensor<TD, DLayout> && D, | |
Tensor<TA, ALayout> const& A, | |
Tensor<TB, BLayout> const& B, | |
Tensor<TC, CLayout> const& C) | |
{ | |
mma_unpack(traits, D, A, B, C); | |
} | |
namespace detail { | |
template <class X, class = void> | |
struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; | |
template <class X> | |
struct FrgTypeA_or_Default<X,void_t<typename X::FrgTypeA>> { using type = typename X::FrgTypeA; }; | |
template <class X, class = void> | |
struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; | |
template <class X> | |
struct FrgTypeB_or_Default<X,void_t<typename X::FrgTypeB>> { using type = typename X::FrgTypeB; }; | |
template <class X, class = void> | |
struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; | |
template <class X> | |
struct FrgTypeC_or_Default<X,void_t<typename X::FrgTypeC>> { using type = typename X::FrgTypeC; }; | |
} // end namespace detail | |
} // namespace cute | |