/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <cstdint>      // uint32_t
#include <cstdio>       // printf

#include <cute/config.hpp>                    // CUTE_DEVICE, CUTE_HOST_DEVICE
#include <cute/numeric/integer_sequence.hpp>  // int_sequence
// __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665
// __nvvm_get_smem_pointer was added in Clang 14: https://reviews.llvm.org/D111665
// ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897
// __cvta_generic_to_shared was added in CUDA 11+
// __nvvm_get_smem_pointer was added in CUDA 10.2
#if defined(__clang__) && defined(__CUDA__) && (__clang_major__ >= 14) && !defined(_WIN32)
#  define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1
#  define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#elif defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 11)
#  define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1
#elif defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2)
#  define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#endif

// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need
// to provide one for NVCC.
#if !defined(__clang__) && defined(CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED)
extern "C" {
// This NVVM intrinsic is subject to change in future versions of CUDA.
// Clients should not call it directly.
CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*);
}
#endif
namespace cute
{

/// CUTE helper to cast SMEM pointer to unsigned
CUTE_DEVICE
uint32_t
cast_smem_ptr_to_uint(void const* const ptr)
{
  // We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
  // the previous internal intrinsics if they are available.
#if defined(__CUDA_ARCH__) && defined(CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED)
  //
  // This NVVM intrinsic converts an address in shared memory to a plain
  // unsigned integer. This is necessary to pass to shared memory instructions
  // in inline PTX.
  //
  // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
  //
  //__device__ size_t __cvta_generic_to_shared(void* ptr);

  /// CUTE helper to get SMEM pointer
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

#elif defined(__CUDA_ARCH__) && defined(CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED)

  // The declaration of __nvvm_get_smem_pointer takes a non-const pointer.
  return __nvvm_get_smem_pointer(const_cast<void*>(ptr));

#elif defined(__CUDA_ARCH__)

  // No intrinsic available: perform the generic-to-shared conversion in PTX.
  uint32_t smem_ptr;
  asm(
  "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
    : "=r"(smem_ptr) : "l"(ptr));
  return smem_ptr;

#else

  (void) ptr;
  printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n");
  return 0;

#endif
}
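
// Illustrative sketch (not part of this header): the returned 32-bit value is
// what shared-memory instructions in inline PTX expect as an address operand.
//
//   __shared__ float buf[256];
//   uint32_t smem_addr = cast_smem_ptr_to_uint(&buf[threadIdx.x]);
//   float v;
//   asm volatile("ld.shared.f32 %0, [%1];\n" : "=f"(v) : "r"(smem_addr));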

namespace detail {

//
// Wrapper for MmaOp::fma
//

template <class MmaOp>
struct CallFMA {
  template <class... Args>
  CUTE_HOST_DEVICE constexpr void
  operator()(Args&&... args) const {
    return MmaOp::fma(static_cast<Args&&>(args)...);
  }
};

//
// Wrapper for CopyOp::copy
//

template <class CopyOp>
struct CallCOPY {
  template <class... Args>
  CUTE_HOST_DEVICE constexpr void
  operator()(Args&&... args) const {
    return CopyOp::copy(static_cast<Args&&>(args)...);
  }
};
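
// Illustrative sketch (not part of CuTe; "MyMma" below is a hypothetical type):
// CallFMA and CallCOPY turn the static member functions MmaOp::fma / CopyOp::copy
// into callable objects so they can be handed to the explode() utilities below.
//
//   struct MyMma {
//     CUTE_HOST_DEVICE static void
//     fma(float& d, float const& a, float const& b, float const& c) { d = a * b + c; }
//   };
//
//   float d, a = 1.f, b = 2.f, c = 3.f;
//   CallFMA<MyMma>{}(d, a, b, c);   // forwards to MyMma::fma(d, a, b, c)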

//
// Utility for exploding pointers/arrays/tensors into functions
//

template <class Fn,
          class PtrA, int... I>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrA&& a, int_sequence<I...>)
{
  return fn(a[I]...);
}

template <class Fn,
          class PtrS, int... Is,
          class PtrD, int... Id>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrS&& s, int_sequence<Is...>,
        PtrD&& d, int_sequence<Id...>)
{
  return fn(s[Is]..., d[Id]...);
}

template <class Fn,
          class PtrA, int... Ia,
          class PtrB, int... Ib,
          class PtrC, int... Ic>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrA&& a, int_sequence<Ia...>,
        PtrB&& b, int_sequence<Ib...>,
        PtrC&& c, int_sequence<Ic...>)
{
  return fn(a[Ia]..., b[Ib]..., c[Ic]...);
}

template <class Fn,
          class PtrD, int... Id,
          class PtrA, int... Ia,
          class PtrB, int... Ib,
          class PtrC, int... Ic>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrD&& d, int_sequence<Id...>,
        PtrA&& a, int_sequence<Ia...>,
        PtrB&& b, int_sequence<Ib...>,
        PtrC&& c, int_sequence<Ic...>)
{
  return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...);
}

template <class Fn,
          class PtrD, int... Id,
          class PtrA, int... Ia,
          class PtrB, int... Ib,
          class PtrC, int... Ic,
          class PtrE, int... Ie>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrD&& d, int_sequence<Id...>,
        PtrA&& a, int_sequence<Ia...>,
        PtrB&& b, int_sequence<Ib...>,
        PtrC&& c, int_sequence<Ic...>,
        PtrE&& e, int_sequence<Ie...>)
{
  return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...);
}

template <class Fn,
          class PtrD,   int... Id,
          class PtrA,   int... Ia,
          class PtrB,   int... Ib,
          class PtrC,   int... Ic,
          class PtrSFA, int... Isfa,
          class PtrSFB, int... Isfb>
CUTE_HOST_DEVICE constexpr
void
explode(Fn fn,
        PtrD&&   d,   int_sequence<Id...>,
        PtrA&&   a,   int_sequence<Ia...>,
        PtrB&&   b,   int_sequence<Ib...>,
        PtrC&&   c,   int_sequence<Ic...>,
        PtrSFA&& sfa, int_sequence<Isfa...>,
        PtrSFB&& sfb, int_sequence<Isfb...>)
{
  return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., sfa[Isfa]..., sfb[Isfb]...);
}
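
// Illustrative sketch (hypothetical values, reusing the MyMma sketch above):
// explode() turns arrays plus index sequences into one flat argument pack.
//
//   float d[1], a[1] = {1.f}, b[1] = {2.f}, c[1] = {3.f};
//   explode(CallFMA<MyMma>{},
//           d, make_int_sequence<1>{},
//           a, make_int_sequence<1>{},
//           b, make_int_sequence<1>{},
//           c, make_int_sequence<1>{});
//   // expands to MyMma::fma(d[0], a[0], b[0], c[0])
//
// make_int_sequence<N> is assumed from CuTe's integer-sequence utilities.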

//
// Utility for exploding tuples into functions
//

template <class Fn,
          class TupleA, int... I>
CUTE_HOST_DEVICE constexpr
void
explode_tuple(Fn fn,
              TupleA&& a, int_sequence<I...>)
{
  return fn(get<I>(a)...);
}

template <class Fn,
          class TupleA, int... Ia,
          class TupleB, int... Ib>
CUTE_HOST_DEVICE constexpr
void
explode_tuple(Fn fn,
              TupleA&& a, int_sequence<Ia...>,
              TupleB&& b, int_sequence<Ib...>)
{
  return fn(get<Ia>(a)..., get<Ib>(b)...);
}

template <class Fn,
          class TupleA, int... Ia,
          class TupleB, int... Ib,
          class TupleC, int... Ic>
CUTE_HOST_DEVICE constexpr
void
explode_tuple(Fn fn,
              TupleA&& a, int_sequence<Ia...>,
              TupleB&& b, int_sequence<Ib...>,
              TupleC&& c, int_sequence<Ic...>)
{
  return fn(get<Ia>(a)..., get<Ib>(b)..., get<Ic>(c)...);
}
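
// Illustrative sketch (hypothetical lambda): same idea as explode(), but the
// arguments are pulled out of tuples with get<I> instead of operator[].
//
//   auto t = cute::make_tuple(1, 2, 3);
//   explode_tuple([] (int x, int y, int z) { /* use x, y, z */ },
//                 t, make_int_sequence<3>{});
//   // expands to fn(get<0>(t), get<1>(t), get<2>(t))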

} // end namespace detail

} // end namespace cute