/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
namespace cute | |
{ | |
// Primary template: declared but never defined. Only the partial specializations
// in this file (one keyed on a bare copy operation, one keyed on Copy_Traits<...>)
// provide a usable Copy_Atom.
template <class... Args>
struct Copy_Atom;
// Convenience form: an atom named by a copy operation type simply derives from the
// atom built on Copy_Traits<CopyOperation>, so it adds no members of its own.
template <class CopyOperation, class CopyInternalType>
struct Copy_Atom<CopyOperation, CopyInternalType> : Copy_Atom<Copy_Traits<CopyOperation>, CopyInternalType>
{};
template <class... Args, class CopyInternalType> | |
struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType> | |
: Copy_Traits<Args...> | |
{ | |
using Traits = Copy_Traits<Args...>; | |
// Bit and Thr layouts from the Copy_Traits | |
using ThrID = typename Traits::ThrID; | |
using BitLayoutSrc = typename Traits::SrcLayout; | |
using BitLayoutDst = typename Traits::DstLayout; | |
using BitLayoutRef = typename Traits::RefLayout; | |
using ValType = CopyInternalType; | |
using ValLayoutSrc = decltype(recast_layout<uint1_t, ValType>(BitLayoutSrc{})); | |
using ValLayoutDst = decltype(recast_layout<uint1_t, ValType>(BitLayoutDst{})); | |
using ValLayoutRef = decltype(recast_layout<uint1_t, ValType>(BitLayoutRef{})); | |
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); | |
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); | |
CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); | |
static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); | |
static constexpr int NumValDst = size<1>(ValLayoutDst{}); | |
// Additional Trait parameters/transformations | |
template <class... TraitsArgs> | |
CUTE_HOST_DEVICE | |
auto | |
with(TraitsArgs&&... args) const { | |
auto traits = Traits::with(static_cast<TraitsArgs&&>(args)...); | |
return Copy_Atom<decltype(traits), CopyInternalType>{traits}; | |
} | |
// | |
// Tensor call interfaces | |
// | |
// Check and call instruction, or recurse | |
template <class SEngine, class SLayout, | |
class DEngine, class DLayout> | |
CUTE_HOST_DEVICE | |
void | |
call(Tensor<SEngine,SLayout> const& src, | |
Tensor<DEngine,DLayout> & dst) const | |
{ | |
static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); | |
static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); | |
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value || | |
is_constant<NumValDst, decltype(size(dst))>::value) { | |
// Dispatch to unpack to execute instruction | |
return copy_unpack(*this, src, dst); | |
} else | |
if constexpr (is_tuple<decltype(shape(src))>::value && | |
is_tuple<decltype(shape(dst))>::value) { | |
// If the size of the src/dst doesn't match the instruction, | |
// recurse this rank-1 layout by peeling off the mode | |
// ((A,B,C,...)) -> (A,B,C,...) | |
return copy(*this, tensor<0>(src), tensor<0>(dst)); | |
} else { | |
static_assert(dependent_false<SEngine>, "No instruction match and no recursion possible."); | |
} | |
} | |
// Accept mutable temporaries | |
template <class SEngine, class SLayout, | |
class DEngine, class DLayout> | |
CUTE_HOST_DEVICE | |
void | |
call(Tensor<SEngine,SLayout> const& src, | |
Tensor<DEngine,DLayout> && dst) const | |
{ | |
return call(src, dst); | |
} | |
}; | |
// | |
// A tiling of copy atoms | |
// | |
// Forward declaration: TiledCopy::get_slice (below) returns a ThrCopy before the
// ThrCopy definition appears later in this file.
template <class TiledCopy, class ThrIdx>
struct ThrCopy;
// TiledCopy: a Copy_Atom combined with a thread-value layout (LayoutCopy_TV) and a
// tiler (ShapeTiler_MN). It derives from the Copy_Atom, so the atom's members remain
// accessible. Every helper below is static and depends only on the template
// parameters, never on instance state.
template <class Copy_Atom,
class LayoutCopy_TV, // (tid,vid) -> coord [Need not be 2D...]
class ShapeTiler_MN> // coord space
struct TiledCopy : Copy_Atom
{
// Layout information from the CopyAtom
using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx
using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset
using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset
using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset
using AtomNumThr = decltype(size<0>(AtomLayoutRef{}));
using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));
// Layout information for the TiledCopy
using Tiler_MN = ShapeTiler_MN;
using TiledLayout_TV = LayoutCopy_TV;
using TiledNumThr = decltype(size<0>(TiledLayout_TV{}));
using TiledNumVal = decltype(size<1>(TiledLayout_TV{}));
// The tiled thread/value counts must be whole multiples of the atom's counts
CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom");
CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom");
// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
// where
// ThrV: The threads local to a COPY_ATOM Src.
// ThrX: The threads tiled across COPY_ATOMs Src.
// FrgV: The values local to a COPY_ATOM Src.
// RestM: The values tiled in M.
// RestN: The values tiled in N.
template <class STensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_S(STensor&& stensor)
{
CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
// Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout
return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
}
// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
// where
// ThrV: The threads local to a COPY_ATOM Dst.
// ThrX: The threads tiled across COPY_ATOMs Dst.
// FrgV: The values local to a COPY_ATOM Dst.
// RestM: The values tiled in M.
// RestN: The values tiled in N.
template <class DTensor>
CUTE_HOST_DEVICE constexpr static
auto
tidfrg_D(DTensor&& dtensor)
{
CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small.");
// Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout
return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
}
// Tile a tensor or a layout from shape
// ((TileM,TileN,...), (RestM,RestN,...))
// to shape
// ((ThrV,ThrX),FrgV,(RestM,RestN,...))
// `ref2trg` is the layout composed into the atom mode of the TV layout; the callers
// in this struct pass right_inverse(AtomLayoutRef).compose(AtomLayoutSrc|Dst).
template <class Tensor, class Ref2TrgLayout>
CUTE_HOST_DEVICE constexpr static
auto
tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg)
{
// Take the thrs/vals that the atom is interested in
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{}));
// ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n)
// Transform to the trg layout
auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _);
// ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n)
// Transform the thrs mode from thrid to thr_idx
// NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{});
// ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n)
/// ==================
// Transform the tile mode
auto tv_tensor = tensor.compose(thrval2mn, _);
// ((thrid,val),(RestM,RestN,...))
// Unfold and return
return tv_tensor(make_coord(_,_), _);
}
// retile_S and retile_D assume they are working with the reference layout -- they are the same
template <class Tensor>
CUTE_HOST_DEVICE constexpr static
auto
retile(Tensor&& tensor)
{
constexpr int R = remove_cvref_t<Tensor>::rank;
// Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation
// NOTE(review): no such assertion is actually present in this function body.
// Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV.
// Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV
// and that shape is what we gather from the other modes of tensor
auto V = size<0>(tensor);
auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{})));
// (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV
auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
// (atom_vals,rest_vals) -> (v,m,n)
/// =======
// Tile the tensor for TileFrg
auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V));
// ((TileV,TileM,TileN,...),(1,RestM,RestN,...))
// Transform the tile mode
auto v_tensor = t_tensor.compose(frg_layout_v, _);
// ((atom_vals,rest_vals),(1,RM,RN,...))
// Unfold and return
return v_tensor(_, append<R>(Int<0>{},_));
}
// Source-side (thr_idx,val_idx) -> (M,N) layout of one tile: runs tile2thrfrg on a
// reference layout of shape (shape(Tiler_MN), 1) and slices the last mode at 0.
CUTE_HOST_DEVICE constexpr static
auto
get_layoutS_TV()
{
// (M,N) -> (M,N)
auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
// (thr_idx,val_idx) -> (M,N)
return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{});
}
// Returns a cute tuple (layoutS_MK, thrID_S): the right inverse of get_layoutS_TV()
// reshaped to the tiler shape, and a layout over the tiled thread count.
CUTE_HOST_DEVICE constexpr static
auto
get_layoutS_MN()
{
// (thr_idx,val_idx) -> (M,N)
auto layoutS_TV = get_layoutS_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutS_MK, thrID_S);
}
// Destination-side counterpart of get_layoutS_TV(): identical except that it
// composes with AtomLayoutDst instead of AtomLayoutSrc.
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_TV()
{
// (M,N) -> (M,N)
auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{}));
// (thr_idx,val_idx) -> (M,N)
return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{});
}
// Destination-side counterpart of get_layoutS_MN(); returns (layoutD_MK, thrID_D).
CUTE_HOST_DEVICE constexpr static
auto
get_layoutD_MN()
{
// (thr_idx,val_idx) -> (M,N)
auto layoutD_TV = get_layoutD_TV();
// (M,K) -> (thr_idx,val_idx)
auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{}));
// athrid = (v,m,k) -> thr_idx
auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));
return cute::make_tuple(layoutD_MK, thrID_D);
}
// Build the per-thread view (ThrCopy) for `thr_idx`.
// Constrained through __CUTE_REQUIRES on is_integral<ThrIdx>.
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static
auto
get_slice(ThrIdx const& thr_idx)
{
return ThrCopy<TiledCopy, ThrIdx>(thr_idx);
}
// Alternate name for get_slice; forwards to it unchanged.
template <class ThrIdx,
__CUTE_REQUIRES(is_integral<ThrIdx>::value)>
CUTE_HOST_DEVICE static
auto
get_thread_slice(ThrIdx const& thr_idx)
{
return get_slice(thr_idx);
}
};
// ThrCopy: the slice of a TiledCopy that belongs to a single thread index.
// Stores only the thread index; all layout work is delegated to the static
// helpers of TiledCopy.
template <class TiledCopy, class ThrIdx>
struct ThrCopy
{
// Thread index this slice was created for (see the constructor)
ThrIdx thr_idx_;
CUTE_HOST_DEVICE
ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {}
// Partition a source tensor: re-lay stensor's data with TiledCopy::tidfrg_S of its
// layout, then index the first mode with thr_idx_ and keep all remaining modes whole.
template <class STensor>
CUTE_HOST_DEVICE
auto
partition_S(STensor&& stensor) const {
//static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling SrcTensor.");
auto thr_tensor = make_tensor(static_cast<STensor&&>(stensor).data(), TiledCopy::tidfrg_S(stensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));
}
// Partition a destination tensor: same as partition_S but uses TiledCopy::tidfrg_D.
template <class DTensor>
CUTE_HOST_DEVICE
auto
partition_D(DTensor&& dtensor) const {
//static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling DstTensor.");
auto thr_tensor = make_tensor(static_cast<DTensor&&>(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout()));
return thr_tensor(thr_idx_, _, repeat<rank_v<DTensor>>(_));
}
// Re-lay a source tensor's data with TiledCopy::retile of its layout.
// Static: does not use thr_idx_.
template <class STensor>
CUTE_HOST_DEVICE static
auto
retile_S(STensor&& stensor) {
// static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling SrcTensor.");
return make_tensor(static_cast<STensor&&>(stensor).data(), TiledCopy::retile(stensor.layout()));
}
// Destination counterpart of retile_S; it calls the same TiledCopy::retile.
template <class DTensor>
CUTE_HOST_DEVICE static
auto
retile_D(DTensor&& dtensor) {
// static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename TiledCopy::ValType),
// "Expected ValType for tiling DstTensor.");
return make_tensor(static_cast<DTensor&&>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
}
};
template <class... Args, | |
class LayoutCopy_TV, | |
class Tiler> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_impl(Copy_Atom<Args...> const& atom, | |
LayoutCopy_TV const&, | |
Tiler const&) | |
{ | |
return TiledCopy<Copy_Atom<Args...>, LayoutCopy_TV, Tiler>{atom}; | |
} | |
// | |
// These tile the Copy_Atom as a whole | |
// | |
template <class... CArgs, class... MArgs> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_A(Copy_Atom<CArgs...> const& copy_atom, | |
TiledMMA<MArgs...> const& mma) | |
{ | |
return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); | |
} | |
template <class... CArgs, class... MArgs> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_B(Copy_Atom<CArgs...> const& copy_atom, | |
TiledMMA<MArgs...> const& mma) | |
{ | |
return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); | |
} | |
template <class... CArgs, class... MArgs> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_C(Copy_Atom<CArgs...> const& copy_atom, | |
TiledMMA<MArgs...> const& mma) | |
{ | |
return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); | |
} | |
// returns the smallest tiled copy that can retile LayoutC_TV
// for use with pipelined epilogues with subtiled stores
//
// copy_atom: atom whose NumValSrc decides how many values per thread are kept.
// mma:       supplies the C thread-value layout and the tile sizes <0> and <1>.
// Returns make_tiled_copy_impl(copy_atom, layout_tv, tiler) with the truncated TV
// layout and the by-mode tiler computed below.
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_atom(Copy_Atom<CArgs...> const& copy_atom,
TiledMMA<MArgs...> const& mma)
{
// Truncate the V-layout to just the Copy_Atom, keep the V-order
auto layoutC_TV = mma.get_layoutC_TV();
auto copy_V = Int<Copy_Atom<CArgs...>::NumValSrc>{};
// The atom may not need more values per thread than the MMA's C layout provides
CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV));
auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V)));
// Recompute tiler and restride the TV layout for the new tiler
// Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them
// Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA
auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma));
auto mma_zeros = repeat_like(mma_tiler, Int<0>{});
// For each mode i: a layout over mma_tiler with stride 1 in mode i and 0 elsewhere,
// composed with layout_TV and then filtered.
auto tiler = transform(make_seq<rank(mma_tiler)>{}, [&](auto i) {
return filter(composition(make_layout(mma_tiler, replace<i>(mma_zeros, Int<1>{})), layout_TV));
});
// Layout_TV -- Find the (tid,vid) -> tile coord transformation
// Apply the tiler to a reference and transform the codomain
// tile_coord -> mma_coord
auto tile2mma = composition(make_layout(mma_tiler), tiler);
// (tid,vid) -> tile_coord
auto layout_tv = composition(left_inverse(tile2mma), layout_TV);
return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
/** Produce a TiledCopy from logical thread and values layouts. | |
* The thread and value layouts map coordinates to thr_idx and val_idx. | |
* The product of these layouts is taken to produce the TV layout and the Tiler. | |
* Useful when threads and values need very specific mappings onto coordinates | |
* in the target tensors. | |
*/ | |
template <class... Args, | |
class ThrLayout, | |
class ValLayout = Layout<_1>> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy(Copy_Atom<Args...> const& copy_atom, | |
ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx | |
ValLayout const& val_layout = {}) // (m,n) -> val_idx | |
{ | |
// Take the raked_products to compute the Layout_MN | |
// (M,N) -> (thr_idx, val_idx) | |
auto layout_mn = raked_product(thr_layout, val_layout); | |
// (thr_idx, val_idx) -> (M,N) | |
auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); | |
// Tiler for extracting relevant elements | |
// (M,N) -> tensor coord | |
auto tiler = product_each(shape(layout_mn)); | |
print("thr_layout: "); print(thr_layout); print("\n"); | |
print("val_layout: "); print(val_layout); print("\n"); | |
print("layout_mn : "); print(layout_mn); print("\n"); | |
print("layout_tv : "); print(layout_tv); print("\n"); | |
print("tiler : "); print(tiler); print("\n"); | |
return make_tiled_copy_impl(copy_atom, layout_tv, tiler); | |
} | |
/** Produce a TiledCopy from thread and value offset maps. | |
* The TV Layout maps threads and values to the codomain of the data_layout. | |
* It is verified that the intended codomain is valid within data_layout. | |
* Useful when threads and values don't care about owning specific coordinates, but | |
* care more about the vector-width and offsets between them. | |
*/ | |
template <class... Args, class AtomTVLayout, class DataLayout> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
make_cotiled_copy(Copy_Atom<Args...> const& copy_atom, | |
AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr | |
DataLayout const& data_layout) // coord -> data addr The target layout | |
{ | |
static_assert(is_static<AtomTVLayout>::value); | |
static_assert(is_static<DataLayout>::value); | |
// data addr -> data coord Append 1:0 so off-the-ends get the stride-0 | |
auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); | |
// (tid,vid) -> data_coord | |
auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); | |
// Check validity | |
CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), | |
"The memory pointed to by AtomTVLayout does not exist in the DataLayout."); | |
if (thread0()) { | |
print("data_layout : "); print(data_layout); print("\n"); | |
print("atom_tv_layout : "); print(atom_tv_layout); print("\n"); | |
print("layout_tv_data : "); print(layout_tv_data); print("\n"); | |
} | |
// | |
// Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them | |
// | |
// Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA | |
auto flat_data_shape = product_each(shape(data_layout)); | |
auto flat_data_zeros = repeat<rank(flat_data_shape)>(Int<0>{}); | |
auto tiler = transform(make_seq<rank(flat_data_shape)>{}, [&](auto i) { | |
return filter(composition(make_layout(flat_data_shape, replace<i>(flat_data_zeros, Int<1>{})), layout_tv_data)); | |
}); | |
// | |
// Layout_TV -- Find the (tid,vid) -> tile coord transformation | |
// | |
// Apply the tiler to a reference and transform the codomain | |
// tile_coord -> data_coord | |
auto tile2data = composition(make_layout(flat_data_shape), tiler); | |
// (tid,vid) -> tile_coord | |
auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); | |
if (thread0()) { | |
print("tiler : "); print(tiler); print("\n"); | |
print("tile2data : "); print(tile2data); print("\n"); | |
print("layout_tv : "); print(layout_tv); print("\n"); | |
} | |
return make_tiled_copy_impl(copy_atom, layout_tv, tiler); | |
} | |
// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy | |
template <class... Args, | |
class TiledCopy> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_S(Copy_Atom<Args...> const& copy_atom, | |
TiledCopy const& tiled_copy) | |
{ | |
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); | |
} | |
// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy | |
template <class... Args, | |
class TiledCopy> | |
CUTE_HOST_DEVICE | |
auto | |
make_tiled_copy_D(Copy_Atom<Args...> const& copy_atom, | |
TiledCopy const& tiled_copy) | |
{ | |
return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); | |
} | |
// | |
// Size | |
// | |
// The logical size of a TileCopy | |
template <int... I, class... Args> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
tile_size(TiledCopy<Args...> const&) | |
{ | |
return size<I...>(typename TiledCopy<Args...>::Tiler_MN{}); | |
} | |
// The number of threads involved in a TiledCopy | |
template <class... Args> | |
CUTE_HOST_DEVICE constexpr | |
auto | |
size(TiledCopy<Args...> const&) | |
{ | |
return typename TiledCopy<Args...>::TiledNumThr{}; | |
} | |
// | |
// Display utilities | |
// | |
template <class... Args, class T> | |
CUTE_HOST_DEVICE | |
void | |
print(Copy_Atom<Copy_Traits<Args...>, T> const&) | |
{ | |
using Atom = Copy_Atom<Copy_Traits<Args...>, T>; | |
print("Copy_Atom\n"); | |
print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); | |
print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); | |
print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); | |
print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); | |
print(" ValueType: "); print(sizeof_bits<typename Atom::ValType>::value); print("b\n"); | |
} | |
template <class Atom, class... Args> | |
CUTE_HOST_DEVICE | |
void | |
print(TiledCopy<Atom, Args...> const& copy, char const* pad = "") | |
{ | |
using Copy = TiledCopy<Atom, Args...>; | |
print("TiledCopy\n"); | |
print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); | |
print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); | |
print(static_cast<Atom const&>(copy)); | |
} | |
template <class TiledCopy, class ThrIdx> | |
CUTE_HOST_DEVICE | |
void | |
print(ThrCopy<TiledCopy, ThrIdx> const& thr_copy) | |
{ | |
print("ThrCopy\n"); | |
print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); | |
print(TiledCopy{}); | |
} | |
template <class... Args> | |
CUTE_HOST_DEVICE | |
auto | |
print_latex(TiledCopy<Args...> const& copy) | |
{ | |
auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); | |
auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); | |
print_latex_copy(layoutS_MN, thrID_S, | |
layoutD_MN, thrID_D); | |
} | |
// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread | |
template <class LayoutS, class ThrIDS, | |
class LayoutD, class ThrIDD> | |
CUTE_HOST_DEVICE | |
void | |
print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx | |
LayoutD const& D, ThrIDD const& TD) // (m,n) -> (tid,vid) and tid -> thr_idx | |
{ | |
CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); | |
CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); | |
assert(size<0>(S) == size<0>(D)); | |
assert(size<1>(S) == size<1>(D)); | |
char const* latex_header = | |
"\\documentclass{standalone}\n" | |
"\\usepackage{tikz}\n" | |
"\\usetikzlibrary{external}\n" | |
"\\tikzexternalize\n" | |
"\\begin{document}\n" | |
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; | |
char const* latex_footer = | |
"\\end{tikzpicture}\n" | |
"\\end{document}\n"; | |
char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", | |
"{rgb,255:red,175;green,255;blue,175}", | |
"{rgb,255:red,255;green,255;blue,175}", | |
"{rgb,255:red,255;green,175;blue,175}", | |
"{rgb,255:red,210;green,210;blue,255}", | |
"{rgb,255:red,210;green,255;blue,210}", | |
"{rgb,255:red,255;green,255;blue,210}", | |
"{rgb,255:red,255;green,210;blue,210}",}; | |
// Header | |
printf("%% LayoutS: "); print(S); printf("\n"); | |
printf("%% ThrIDS : "); print(TS); printf("\n"); | |
printf("%% LayoutD: "); print(D); printf("\n"); | |
printf("%% ThrIDD : "); print(TD); printf("\n\n"); | |
printf(latex_header); | |
// S starting at 0,0 | |
for (int i = 0; i < size<0>(S); ++i) { | |
for (int j = 0; j < size<1>(S); ++j) { | |
int thrid = S(i,j) % size(TS); | |
int val_idx = S(i,j) / size(TS); | |
int thr_idx = TS(thrid); | |
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", | |
color_map[thr_idx % 8], | |
i, j, | |
thr_idx, val_idx); | |
} | |
} | |
// D starting at 0,size<1>(S)+3 | |
for (int i = 0; i < size<0>(D); ++i) { | |
for (int j = 0; j < size<1>(D); ++j) { | |
int thrid = D(i,j) % size(TD); | |
int val_idx = D(i,j) / size(TD); | |
int thr_idx = TD(thrid); | |
printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", | |
color_map[thr_idx % 8], | |
i, j + size<1>(S) + 3, | |
thr_idx, val_idx); | |
} | |
} | |
// S Labels | |
for (int i = 0, j = -1; i < size<0>(S); ++i) { | |
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); | |
} | |
for (int j = 0, i = -1; j < size<1>(S); ++j) { | |
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); | |
} | |
// D Labels | |
for (int i = 0, j = size<1>(D); i < size<0>(S); ++i) { | |
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); | |
} | |
for (int j = 0, i = -1; j < size<1>(D); ++j) { | |
printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); | |
} | |
// Footer | |
printf(latex_footer); | |
} | |
} // end namespace cute | |
//////////////////////////////////////////////////////////////////////////////////////////////////// | |
// Config | |
//////////////////////////////////////////////////////////////////////////////////////////////////// | |