Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
 **************************************************************************************************/
/* This implements a swizzle pointer of the form
 *   InvolutionFn o PtrAdd
 * where the InvolutionFn need not be linear.
 *
 * This differs subtly from swizzle_layout because the smem pointer is used
 * as the offset. That means that swizzle_layout will implement position-independent
 * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors.
 * The hardware was designed with position-dependent swizzles.
 *
 * For clarity:
 *   NormalLayout  : DeRef <- PtrAdd <- [Layout]
 *   ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout]
 *   SwizzlePtr    : [DeRef <- Swizzle <- PtrAdd] <- Layout
 *
 * Furthermore, for known swizzles, this pointer attempts to decay itself
 * to a normal-pointer with a new layout containing dynamic or static strides.
 * This is possible by determining the subdomain of the InvolutionFn
 * that is identity and testing if the Layout's codomain is contained
 * within it.
 */
namespace cute | |
{ | |
// concept SwizzleFn {
//   CUTE_HOST_DEVICE constexpr static uint apply(uint);
// }
// See Swizzle<B,M,S> in swizzle.hpp for common swizzle-functions.
template <class SwizzleFn, class Iterator> | |
struct swizzle_ptr : iter_adaptor<Iterator,swizzle_ptr<SwizzleFn,Iterator>> | |
{ | |
using iterator = Iterator; | |
using reference = typename iterator_traits<iterator>::reference; | |
using element_type = typename iterator_traits<iterator>::element_type; | |
using value_type = typename iterator_traits<iterator>::value_type; | |
using iter_adaptor<Iterator,swizzle_ptr<SwizzleFn,Iterator>>::iter_adaptor; | |
template <class Iter> | |
CUTE_HOST_DEVICE constexpr static | |
Iter apply_swizzle(Iter ptr) { | |
return {apply_swizzle(ptr.get())}; | |
} | |
template <class T> | |
CUTE_HOST_DEVICE constexpr static | |
T* apply_swizzle(T* ptr) { | |
return reinterpret_cast<T*>(SwizzleFn::apply(reinterpret_cast<uintptr_t>(ptr))); | |
} | |
template <class T> | |
CUTE_HOST_DEVICE constexpr static | |
subbyte_iterator<T> apply_swizzle(subbyte_iterator<T> ptr) { | |
return {apply_swizzle(ptr.ptr_), ptr.idx_}; | |
} | |
CUTE_HOST_DEVICE constexpr | |
reference operator*() const { | |
return *apply_swizzle(this->get()); | |
} | |
template <class Int> | |
CUTE_HOST_DEVICE constexpr | |
reference operator[](Int const& i) const { | |
return *apply_swizzle(this->get() + i); | |
} | |
}; | |
template <class T, class = void> // Default No-Swizzle | |
struct get_swizzle { using type = Swizzle<0,4,3>; }; | |
template <class SwizzleFn, class P> // Found the SwizzleFn | |
struct get_swizzle<swizzle_ptr<SwizzleFn,P>> { using type = SwizzleFn; }; | |
template <class T> // Recurse into anything with a ::iterator | |
struct get_swizzle<T, void_t<typename T::iterator>> : get_swizzle<typename T::iterator> {}; | |
template <class Iter> | |
using get_swizzle_t = typename get_swizzle<Iter>::type; | |
template <class Iterator, class SwizzleFn> | |
CUTE_HOST_DEVICE constexpr | |
swizzle_ptr<SwizzleFn,Iterator> | |
make_swizzle_ptr(Iterator ptr, SwizzleFn) { | |
return {ptr}; | |
} | |
// Swizzle-0 specialization for immediate decay | |
template <class Iterator, int M, int S> | |
CUTE_HOST_DEVICE constexpr | |
Iterator | |
make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>) { | |
return ptr; | |
} | |
//
// Recast
//
template <class SwizzleFn, class P> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
raw_pointer_cast(swizzle_ptr<SwizzleFn,P> const& ptr) { | |
return raw_pointer_cast(ptr.get()); | |
} | |
// SwizzleFn operates on the pointer address, so it doesn't care about the type | |
template <class NewT, class SwizzleFn, class P> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
recast_ptr(swizzle_ptr<SwizzleFn,P> const& ptr) { | |
return make_swizzle_ptr(recast_ptr<NewT>(ptr.get()), SwizzleFn{}); | |
} | |
//
// Display utilities
//
template <class SwizzleFn, class P> | |
CUTE_HOST_DEVICE void print(swizzle_ptr<SwizzleFn,P> ptr) | |
{ | |
print(SwizzleFn{}); printf("_"); print(ptr.get()); | |
} | |
template <class SwizzleFn, class P> | |
CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr<SwizzleFn,P> ptr) | |
{ | |
return os << SwizzleFn{} << "_" << ptr.get(); | |
} | |
} // end namespace cute | |