Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
namespace cute | |
{ | |
// | |
// Stand-in Swizzle Layout | |
// A model of a nullptr smem_ptr<T> with B == sizeof_bits<T>::value | |
// That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr | |
// | |
template <int Bits> | |
struct smem_ptr_flag_bits : Int<0> {}; | |
using smem_ptr_flag = smem_ptr_flag_bits<1>; | |
// A flagged construction method to transform ComposedLayout | |
// Make a swizzle pointer tensor and check that the intended type size matches | |
template <class Iterator, class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
make_tensor(Iterator const& ptr, | |
ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
static_assert(is_smem<Iterator>::value, "Expected smem."); | |
static_assert(B == sizeof_bits<iter_value_t<Iterator>>::value, "Expected a B-bit pointer type."); | |
return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), | |
layout.layout_b()); | |
} | |
// NOTE: To preserve smem_ptr_flag_bits under recast ops | |
template <int N, class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
upcast(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
return composition(layout.layout_a(), smem_ptr_flag_bits<B*N>{}, upcast<N>(layout.layout_b())); | |
} | |
template <int N, class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
downcast(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
return composition(layout.layout_a(), smem_ptr_flag_bits<B/N>{}, downcast<N>(layout.layout_b())); | |
} | |
// | |
// Conversion with swizzle_layout | |
// | |
template <class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE | |
auto | |
as_position_independent_swizzle_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
return composition(recast_layout<uint8_t,uint_bit_t<B>>(layout.layout_a()), Int<0>{}, layout.layout_b()); | |
} | |
template <class Tensor> | |
CUTE_HOST_DEVICE | |
auto | |
as_position_independent_swizzle_tensor(Tensor&& tensor) | |
{ | |
static_assert(is_smem<remove_cvref_t<Tensor>>::value, "Expected smem tensor."); | |
using SwizzleFn = get_swizzle_t<remove_cvref_t<Tensor>>; | |
if constexpr (SwizzleFn::num_bits == 0) { | |
return tensor; | |
} else { | |
{ | |
uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(static_cast<Tensor&&>(tensor).data())); | |
uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code; | |
assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle | |
} | |
using T = typename remove_cvref_t<Tensor>::value_type; | |
// Recast swizzle from acting on byte-addressed pointers to elements of type-T | |
auto new_swizzle = recast_layout<uint8_t, T>(SwizzleFn{}); | |
// Strip off everything and create a new smem_ptr for type-T | |
auto new_ptr = make_smem_ptr<T>(raw_pointer_cast(static_cast<Tensor&&>(tensor).data())); | |
return make_tensor(new_ptr, composition(new_swizzle, Int<0>{}, tensor.layout())); | |
} | |
CUTE_GCC_UNREACHABLE; | |
} | |
// | |
// Display utilities | |
// | |
// Capture and cast smem_ptr_flag Layouts to offset-0 layouts | |
template <class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE | |
void | |
print_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
print_layout(as_position_independent_swizzle_layout(layout)); | |
} | |
template <class SwizzleFn, int B, class Layout> | |
CUTE_HOST_DEVICE | |
void | |
print_latex(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout) | |
{ | |
print_latex(as_position_independent_swizzle_layout(layout)); | |
} | |
template <int B> | |
CUTE_HOST_DEVICE void print(smem_ptr_flag_bits<B> ptr) | |
{ | |
printf("smem_ptr[%db](unset)", B); | |
} | |
} // end namespace cute | |