Spaces:
Sleeping
Sleeping
File size: 36,547 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 |
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cute/numeric/integer_sequence.hpp>
#include <cute/container/tuple.hpp>
#include <cute/container/array_aligned.hpp>
#include <cute/container/array_subbyte.hpp>
#include <cute/pointer.hpp>
#include <cute/layout.hpp>
namespace cute
{
//
// Engine -- owning or non-owning data store
//
// concept Engine {
// using iterator = ;
// using value_type = ;
// using element_type = ;
// using reference = ;
// iterator begin();
// };
// ArrayEngine -- an owning Engine backed by Storage.
//   * Storage is array_aligned<T,N> when T occupies a whole number of bytes
//     (sizeof_bits<T>::value % 8 == 0), and array_subbyte<T,N> otherwise.
//   * iterator is Storage's iterator; reference, element_type and value_type are read
//     from iterator_traits of that iterator.
template <class T, int N>
struct ArrayEngine
{
  using Storage = typename conditional<(sizeof_bits<T>::value % 8 == 0),
                                       array_aligned<T,N>,
                                       array_subbyte<T,N>>::type;
  using iterator     = typename Storage::iterator;
  using reference    = typename iterator_traits<iterator>::reference;
  using element_type = typename iterator_traits<iterator>::element_type;
  using value_type   = typename iterator_traits<iterator>::value_type;
  // The element storage itself, held by value in the engine.
  Storage storage_;
  // Iterator to the first element, forwarded from Storage::begin().
  CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); }
  CUTE_HOST_DEVICE constexpr auto begin()       { return storage_.begin(); }
};
// ViewEngine -- a non-owning Engine: it holds only an iterator into data stored elsewhere.
// reference, element_type and value_type are read from iterator_traits<Iterator>.
template <class Iterator>
struct ViewEngine
{
  using iterator     = Iterator;
  using reference    = typename iterator_traits<iterator>::reference;
  using element_type = typename iterator_traits<iterator>::element_type;
  using value_type   = typename iterator_traits<iterator>::value_type;
  // The wrapped iterator.
  iterator storage_;
  // Return the stored iterator itself by reference; the non-const overload lets callers
  // modify the stored iterator in place.
  CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; }
  CUTE_HOST_DEVICE constexpr iterator      & begin()       { return storage_; }
};
// ConstViewEngine -- like ViewEngine (holds only an iterator), but exposes the stored
// iterator through a const begin() only, so it cannot be reassigned through begin().
template <class Iterator>
struct ConstViewEngine
{
  using iterator     = Iterator;
  using reference    = typename iterator_traits<iterator>::reference;
  using element_type = typename iterator_traits<iterator>::element_type;
  using value_type   = typename iterator_traits<iterator>::value_type;
  // The wrapped iterator.
  iterator storage_;
  // Const access to the stored iterator.
  CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; }
};
//
// Tensor
//
// Tensor -- a Layout paired with an Engine.
//
// The Layout maps logical coordinates to offsets; the Engine supplies the iterator
// (data()) those offsets are applied to. Both are held by value in rep_, in the order
// (layout, engine). Whether a Tensor owns its elements is decided by its Engine
// (e.g. ArrayEngine vs ViewEngine).
template <class Engine, class Layout>
struct Tensor
{
  using iterator     = typename Engine::iterator;
  using value_type   = typename Engine::value_type;
  using element_type = typename Engine::element_type;
  using reference    = typename Engine::reference;
  using engine_type  = Engine;
  using layout_type  = Layout;

  // Default constructor: rep_ (layout and engine) is default-constructed.
  CUTE_HOST_DEVICE constexpr
  Tensor() {}

  // Construct from a data pointer/iterator and a layout. rep_ is built from
  // (layout, ptr), i.e. the engine is constructed from ptr.
  template <class Ptr>
  CUTE_HOST_DEVICE constexpr
  Tensor(Ptr const& ptr, Layout const& layout)
    : rep_(layout, ptr) {
  }

  //
  // Accessors
  //

  // Rank of the tensor: the rank of its Layout.
  static constexpr int rank = Layout::rank;

  // Returns *this; in this const member function that is a Tensor const&.
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  tensor() const {
    return *this;
  }

  // The stored Layout (element 0 of rep_).
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  layout() const {
    return get<0>(rep_);
  }

  // The stored Engine (element 1 of rep_), const overload.
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  engine() const {
    return get<1>(rep_);
  }

  // The stored Engine (element 1 of rep_), non-const overload.
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  engine() {
    return get<1>(rep_);
  }

  // Iterator to the first element, as returned by engine().begin().
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  data() const {
    return engine().begin();
  }

  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  data() {
    return engine().begin();
  }

  // Shape of the Layout.
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  shape() const {
    return layout().shape();
  }

  // Number of elements in the logical domain: cute::size of the shape.
  CUTE_HOST_DEVICE constexpr
  auto
  size() const {
    return cute::size(shape());
  }

  // Stride of the Layout.
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  stride() const {
    return layout().stride();
  }

  //
  // Indexing op() and op[]
  //

  // Index into this tensor like an array by computing the offset via layout().
  // op[] never slices: it always evaluates data()[layout()(coord)].
  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator[](Coord const& coord) {
    return data()[layout()(coord)];
  }

  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator[](Coord const& coord) const {
    return data()[layout()(coord)];
  }

  // op(coord):
  //   * If coord contains an underscore (has_underscore<Coord>), slice: slice_and_offset
  //     yields a sub-layout and an offset, and a new tensor is made over data() + offset.
  //   * Otherwise, access the element data()[layout()(coord)].
  // The trailing CUTE_GCC_UNREACHABLE presumably silences missing-return diagnostics
  // after the if constexpr -- NOTE(review): confirm against cute/config.hpp.
  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator()(Coord const& coord) {
    if constexpr (has_underscore<Coord>::value) {
      auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
      return make_tensor(data() + offset, sliced_layout);
    } else {
      return data()[layout()(coord)];
    }
    CUTE_GCC_UNREACHABLE;
  }

  template <class Coord>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator()(Coord const& coord) const {
    if constexpr (has_underscore<Coord>::value) {
      auto const& [sliced_layout,offset] = slice_and_offset(coord, layout());
      return make_tensor(data() + offset, sliced_layout);
    } else {
      return data()[layout()(coord)];
    }
    CUTE_GCC_UNREACHABLE;
  }

  // op() convenience function for multi-dimensional coordinates:
  // packs (c0,c1,cs...) with make_coord and forwards to the single-coordinate op().
  template <class Coord0, class Coord1, class... Coords>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) {
    return operator()(make_coord(c0,c1,cs...));
  }

  template <class Coord0, class Coord1, class... Coords>
  CUTE_HOST_DEVICE constexpr
  decltype(auto)
  operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
    return operator()(make_coord(c0,c1,cs...));
  }

  //
  // Compose
  //

  // A new tensor over the same data() with layout().compose(layouts...).
  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  compose(Layouts const&... layouts) {
    return make_tensor(data(), layout().compose(layouts...));
  }

  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  compose(Layouts const&... layouts) const {
    return make_tensor(data(), layout().compose(layouts...));
  }

  //
  // Tile
  //

  // A new tensor over the same data() with layout().tile(layouts...).
  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  tile(Layouts const&... layouts) {
    return make_tensor(data(), layout().tile(layouts...));
  }

  template <class... Layouts>
  CUTE_HOST_DEVICE constexpr
  auto
  tile(Layouts const&... layouts) const {
    return make_tensor(data(), layout().tile(layouts...));
  }

  //
  // Utility
  //

  // Coordinate conversions of an integral linear index; each forwards to the Layout
  // member of the same name.
  template <class Int,
            __CUTE_REQUIRES(is_integral<Int>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_1d_coord(Int const& linear_idx) const {
    return layout().get_1d_coord(linear_idx);
  }

  template <class Int,
            __CUTE_REQUIRES(is_integral<Int>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_hier_coord(Int const& linear_idx) const {
    return layout().get_hier_coord(linear_idx);
  }

  template <class Int,
            __CUTE_REQUIRES(is_integral<Int>::value)>
  CUTE_HOST_DEVICE constexpr
  auto
  get_flat_coord(Int const& linear_idx) const {
    return layout().get_flat_coord(linear_idx);
  }

  // (layout, engine), both held by value.
  cute::tuple<layout_type, engine_type> rep_;
};
// is_tensor<T>: true_type exactly when T is a specialization Tensor<Engine,Layout>,
// false_type otherwise. cv/ref-qualified types are not matched; callers in this file
// strip them with remove_cvref_t before testing.
template <class T>
struct is_tensor : false_type {};
template <class Engine, class Layout>
struct is_tensor<Tensor<Engine,Layout>> : true_type {};
// Convenience value form of is_tensor<T>.
template <class T>
constexpr bool is_tensor_v = is_tensor<T>::value;
// Customization point for creation of owning and non-owning Tensors
//
// MakeTensor<T> is the function object behind make_tensor:
//   * If T has no dereference (has_dereference<T> is false), T is the element type and
//     an owning Tensor backed by ArrayEngine<T, cosize_v<Layout>> is built. Only static
//     layouts are accepted (static_assert).
//   * If T has a dereference, T is an iterator and a Tensor backed by ViewEngine<T>
//     is built around the given iterator.
// A Layout may be passed directly, or layout arguments may be passed and are forwarded
// to make_layout.
//
// Every call operator is const. None of them modifies the function object, and giving them
// the same object-parameter qualification matters for overload resolution. If the iterator
// overloads were non-const, they would win on the implicit object parameter. A call such as
// MakeTensor<float>{}(Int<4>{}, Int<8>{}) would then be ambiguous between the owning
// (LayoutArg...) overload (exact match) and the (T const& iter, ...) overload. The latter is
// viable through the implicit Int<4> -> float conversion.
template <class T>
struct MakeTensor
{
  // Owning tensor from a (static) Layout.
  template <class Layout,
            __CUTE_REQUIRES(not has_dereference<T>::value &&
                            is_layout<Layout>::value)>
  CUTE_HOST_DEVICE constexpr auto
  operator()(Layout const& layout) const
  {
    static_assert(is_static<Layout>::value, "Dynamic owning tensors not supported");
    using Engine = ArrayEngine<T, cosize_v<Layout>>;
    return Tensor<Engine,Layout>();
  }

  // Tensor over the iterator iter with the given Layout.
  template <class Layout,
            __CUTE_REQUIRES(has_dereference<T>::value &&
                            is_layout<Layout>::value)>
  CUTE_HOST_DEVICE constexpr auto
  operator()(T const& iter, Layout const& layout) const
  {
    using Engine = ViewEngine<T>;
    return Tensor<Engine,Layout>(iter, layout);
  }

  // Layout arguments only: build the Layout with make_layout and forward.
  template <class LayoutArg, class... LayoutArgs,
            __CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
  CUTE_HOST_DEVICE constexpr auto
  operator()(LayoutArg const& arg, LayoutArgs const&... args) const
  {
    return operator()(make_layout(arg, args...));
  }

  // Iterator plus layout arguments: build the Layout with make_layout and forward.
  template <class LayoutArg, class... LayoutArgs,
            __CUTE_REQUIRES(not is_layout<LayoutArg>::value)>
  CUTE_HOST_DEVICE constexpr auto
  operator()(T const& iter, LayoutArg const& arg, LayoutArgs const&... args) const
  {
    return operator()(iter, make_layout(arg, args...));
  }
};
//
// make_tensor
//
// Owning Tensor: MakeTensor<T> allocates a static array of T for the given layout
//   e.g. make_tensor<float>(Int<12>{})
template <class T, class... Args>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Args const&... args)
{
  MakeTensor<T> maker{};
  return maker(args...);
}

// Non-owning Tensor (a view) over the data addressed by iter
//   e.g. make_tensor(vec.data(), 12)
template <class Iterator, class... Args>
CUTE_HOST_DEVICE constexpr
auto
make_tensor(Iterator const& iter, Args const&... args)
{
  MakeTensor<Iterator> maker{};
  return maker(iter, args...);
}
//
// make_tensor_like
// Make a register tensor the same type and shape and (if possible) order as another tensor
//
// From a Layout: an owning tensor of NewT whose layout is make_layout_like(layout).
template <class NewT, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Layout const& layout)
{
  auto const like_layout = make_layout_like(layout);
  return make_tensor<NewT>(like_layout);
}

// From a Tensor, with an explicit new element type NewT.
template <class NewT, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Tensor<Engine,Layout> const& tensor)
{
  auto const& src_layout = tensor.layout();
  return make_tensor_like<NewT>(src_layout);
}

// From a Tensor, keeping its Engine's value_type.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_tensor_like(Tensor<Engine,Layout> const& tensor)
{
  using SrcValue = typename Engine::value_type;
  return make_tensor_like<SrcValue>(tensor.layout());
}
//
// make_fragment_like --
// Make a tensor the same shape and (if possible) order as another tensor, with special
// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms
// so this allocates the 0th mode with LayoutLeft regardless of the reference layout.
//
// From a Layout: an owning tensor of NewT whose layout is make_fragment_like(layout).
template <class NewT, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Layout const& layout)
{
  auto const frag_layout = make_fragment_like(layout);
  return make_tensor<NewT>(frag_layout);
}

// From a Tensor, with an explicit new element type NewT.
template <class NewT, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Tensor<Engine,Layout> const& tensor)
{
  auto const& src_layout = tensor.layout();
  return make_fragment_like<NewT>(src_layout);
}

// From a Tensor, keeping its Engine's value_type.
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
make_fragment_like(Tensor<Engine,Layout> const& tensor)
{
  using SrcValue = typename Engine::value_type;
  return make_fragment_like<SrcValue>(tensor.layout());
}
//
// make_counting_tensor
// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain.
//
// Bind layout to an inttuple iterator whose start value is a zero of the same profile
// as coshape(layout).
template <class Layout, __CUTE_REQUIRES(is_layout<Layout>::value)>
CUTE_HOST_DEVICE constexpr
auto
make_counting_tensor(Layout const& layout)
{
  auto counting_iter = make_inttuple_iter(repeat_like(coshape(layout), Int<0>{}));
  return make_tensor(counting_iter, layout);
}
//
// make_identity_tensor
// Make a tensor that maps coordinates within a shape to themselves.
//
// A counting tensor over make_identity_layout(shape).
template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
make_identity_tensor(Shape const& shape)
{
  auto const identity = make_identity_layout(shape);
  return make_counting_tensor(identity);
}
//
// Utilities
//
// Return the subtensor of a mode
//
// tensor(t) with no mode indices returns t itself, forwarded with its original value
// category (decltype(auto) of static_cast<Tensor&&>).
template <class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor(Tensor&& tensor)
{
  return static_cast<Tensor&&>(tensor);
}

// tensor<I,Is...>(t) returns a new tensor made from t.data() and the sub-layout
// get<I,Is...>(t.layout()).
// NOTE(review): the result is built from t's data iterator; if t is an rvalue owning
// tensor, that iterator points into storage that dies with t -- confirm callers avoid this.
template <int I, int... Is, class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
tensor(Tensor&& tensor)
{
  return make_tensor(static_cast<Tensor&&>(tensor).data(), get<I,Is...>(tensor.layout()));
}
// Free-function accessors on a Tensor. Each forwards the mode indices Is... to the
// Layout overload of the same name, applied to tensor.layout().

// Return the layout of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
layout(Tensor<Engine,Layout> const& tensor)
{
  return layout<Is...>(tensor.layout());
}

// Return the shape of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
shape(Tensor<Engine,Layout> const& tensor)
{
  return shape<Is...>(tensor.layout());
}

// Return the stride of a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
stride(Tensor<Engine,Layout> const& tensor)
{
  return stride<Is...>(tensor.layout());
}

// Return the number of elements in a mode
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
decltype(auto)
size(Tensor<Engine,Layout> const& tensor)
{
  return size<Is...>(tensor.layout());
}

// Return the rank of a mode (returned by value)
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
rank(Tensor<Engine,Layout> const& tensor)
{
  return rank<Is...>(tensor.layout());
}

// Return the depth of a mode (returned by value)
template <int... Is, class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
depth(Tensor<Engine, Layout> const& tensor)
{
  return depth<Is...>(tensor.layout());
}
//
// Operations to manipulate Tensors like a Layout
//
// Each operation below transforms only the layout. The result is a tensor over the
// same data iterator with the transformed layout.

// Data viewed through flatten(tensor.layout()).
template <class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
flatten(Tensor&& tensor)
{
  auto flat_layout = flatten(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), flat_layout);
}

// Data viewed through coalesce(tensor.layout()).
template <class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor&& tensor)
{
  auto merged_layout = coalesce(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), merged_layout);
}

// Data viewed through coalesce(tensor.layout(), profile).
template <class Tensor, class Profile,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
coalesce(Tensor&& tensor, Profile const& profile)
{
  auto merged_layout = coalesce(tensor.layout(), profile);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), merged_layout);
}

// Data viewed through filter_zeros(tensor.layout()).
template <class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
filter_zeros(Tensor&& tensor)
{
  auto nonzero_layout = filter_zeros(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), nonzero_layout);
}

// Data viewed through filter(tensor.layout()).
template <class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
filter(Tensor&& tensor)
{
  auto filtered_layout = filter(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), filtered_layout);
}

// Data viewed through filter(tensor.layout(), profile).
template <class Tensor, class Profile,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
filter(Tensor&& tensor, Profile const& profile)
{
  auto filtered_layout = filter(tensor.layout(), profile);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), filtered_layout);
}
// Return a tensor with the same shape as the input but offset by a given coordinate.
// The layout-level domain_offset yields the new layout plus an offset, which is added
// to the data iterator.
template <class Coord, class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
domain_offset(Coord const& coord, Tensor&& tensor)
{
  auto [shifted_layout, iter_offset] = domain_offset(coord, tensor.layout());
  auto shifted_iter = static_cast<Tensor&&>(tensor).data() + iter_offset;
  return make_tensor(shifted_iter, shifted_layout);
}
// Group the modes [B,E) into a single mode
// e.g. group_modes<2,4>(make_tensor<int>(Layout<Shape<_1,_2,_3,_4,_5,_6>>{}))
//   => make_tensor<int>(Layout<Shape<_1,_2,Shape<_3,_4>,_5,_6>>{})
// Same data, layout with modes [B,E) grouped via group<B,E>.
template <int B, int E, class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
group_modes(Tensor&& tensor)
{
  auto grouped_layout = group<B,E>(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), grouped_layout);
}
// Return the subtensor of a range of modes [B,E): same data, layout take<B,E>(layout).
template <int B, int E, class Tensor,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
decltype(auto)
take(Tensor&& tensor)
{
  auto taken_layout = take<B,E>(tensor.layout());
  return make_tensor(static_cast<Tensor&&>(tensor).data(), taken_layout);
}
//
// Recast
//
// NOTE: This is very dangerous to do
//   -- doesn't check dynamic integer divisibility
//   -- doesn't check alignment
//
// recast<NewType>(tensor): reinterpret the elements of tensor as NewType.
//   * OldType is the tensor's value_type; the layout is converted with
//     recast_layout<OldType,NewType>.
//   * The data iterator is converted with recast_ptr<NewType>.
//   * For an upcast (sizeof(OldType) < sizeof(NewType)) of a layout that is not a
//     ComposedLayout, the data iterator is first advanced by a computed offset (below).
template <class NewType, class Tensor>
CUTE_HOST_DEVICE constexpr
auto
recast(Tensor&& tensor)
{
  using OldType = typename remove_cvref_t<Tensor>::value_type;
  auto old_layout = tensor.layout();
  auto new_layout = recast_layout<OldType,NewType>(old_layout);
  // If this is an upcast of a normal Layout with static negative strides, then offset as well
  // NOTE(review): the upcast test compares sizeof, not sizeof_bits, so sub-byte types that
  // share a sizeof never take this branch -- confirm that is intended.
  if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout<decltype(old_layout)>::value) {
    // Per flattened mode: old shape minus new shape ...
    auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{});
    // ... multiplied elementwise by the old stride ...
    auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{});
    // ... keeping only negative contributions (min(a,0)) and summing them from Int<0>.
    auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); });
    return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data() + offset), new_layout);
  } else {
    return make_tensor(recast_ptr<NewType>(static_cast<Tensor&&>(tensor).data()         ), new_layout);
  }
  CUTE_GCC_UNREACHABLE;
}
//
// max_common_vector
//
/* Return Int<N> such that N is the maximum number of contiguous elements
 * that logically correspond in the tensors of @a a and @a b. That is,
 * the number of elements that could reasonably be vectorized into a single load/store.
*
* @returns Int<N> with N >= 0
*
* A return value of Int<0> indicates that no such conclusion can be made and no
* vectorization should be attempted.
*
 * Note that the return value does NOT include alignment concerns such as the pointer value and
 * the divisibility of dynamic strides.
*/
// Tensor overload of max_common_vector: returns max_common_vector of the two layouts when
// both tensors are vectorization candidates, and Int<0> otherwise. Candidates require:
//   * equal value_type bit widths (otherwise the copy is also a conversion),
//   * trivially copyable value_types,
//   * reference types that are real references on BOTH sides (not implicit iterators).
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE constexpr
auto
max_common_vector(Tensor<SrcEngine,SrcLayout> const& a,
                  Tensor<DstEngine,DstLayout> const& b)
{
  using SrcType = typename Tensor<SrcEngine,SrcLayout>::value_type;
  using DstType = typename Tensor<DstEngine,DstLayout>::value_type;
  using SrcRef  = typename Tensor<SrcEngine,SrcLayout>::reference;
  // The destination's reference type must come from the destination tensor. Taking it from
  // the source would let a destination whose reference is not a real reference pass the check.
  using DstRef  = typename Tensor<DstEngine,DstLayout>::reference;

  // Determine if vectorization candidates at all
  if constexpr (// Should be the same value_types, else the copy is also performing a cast
                sizeof_bits_v<SrcType> == sizeof_bits_v<DstType> &&
                // The types should be trivially copyable so that vectorization is valid
                is_trivially_copyable<SrcType>::value &&
                is_trivially_copyable<DstType>::value &&
                // Should be load/storing real data, rather than implicit iterators or such
                is_reference<SrcRef>::value &&
                is_reference<DstRef>::value)
  {
    return max_common_vector(a.layout(), b.layout());
  } else {
    return Int<0>{};
  }
  CUTE_GCC_UNREACHABLE;
}
/* Return a layout that points to the maximum number of contiguous elements
 * that logically correspond in the tensors of @a a and @a b. That is,
 * the elements that could reasonably be "vectorized" into a single load/store.
*
* @returns Layout R such that composition(a.layout(), R) and composition(b.layout(), R)
* are both identity Layouts.
*
 * Note that the returned layout does NOT include alignment concerns such as the pointer value and
 * the divisibility of dynamic strides.
*/
// Tensor overload of max_common_layout: returns max_common_layout of the two layouts when
// both tensors are vectorization candidates, and Layout<_1,_0> otherwise. The candidate
// conditions are the same as for max_common_vector on Tensors.
template <class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE constexpr
auto
max_common_layout(Tensor<SrcEngine,SrcLayout> const& a,
                  Tensor<DstEngine,DstLayout> const& b)
{
  using SrcType = typename Tensor<SrcEngine,SrcLayout>::value_type;
  using DstType = typename Tensor<DstEngine,DstLayout>::value_type;
  using SrcRef  = typename Tensor<SrcEngine,SrcLayout>::reference;
  // The destination's reference type must come from the destination tensor. Taking it from
  // the source would let a destination whose reference is not a real reference pass the check.
  using DstRef  = typename Tensor<DstEngine,DstLayout>::reference;

  // Determine if vectorization candidates at all
  if constexpr (// Should be the same value_types, else the copy is also performing a cast
                sizeof_bits_v<SrcType> == sizeof_bits_v<DstType> &&
                // The types should be trivially copyable so that vectorization is valid
                is_trivially_copyable<SrcType>::value &&
                is_trivially_copyable<DstType>::value &&
                // Should be load/storing real data, rather than implicit iterators or such
                is_reference<SrcRef>::value &&
                is_reference<DstRef>::value)
  {
    return max_common_layout(a.layout(), b.layout());
  } else {
    return Layout<_1,_0>{};
  }
  CUTE_GCC_UNREACHABLE;
}
//
// Key algebraic operations -- Divide and Product
//
// Apply a Tiler to the Tensor.
//
// Consider a Tensor with shape (A,B,x,y)
// And a Tiler that is:
//
// * A Layout with shape (BLK_A,BLK_B)
// ** Result Tensor shape ((BLK_A,BLK_B),Rest).
// ** That is, the Tensor and Tile are treated as 1D for the tiling.
// ** See logical_divide(Layout,Layout)
//
// * A Tile<Layout...> with shape <BLK_A,BLK_B>
// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y).
// ** Each mode of the Tile<Layout...> is applied to the corresponding mode of the Tensor.
// ** See logical_divide(Layout,Tuple)
//
// * A Shape (BLK_A,BLK_B)
// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y).
// ** Equivalent to applying Tile<BLK_A:_1,BLK_B:_1>.
// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int)
//
// Note that the Tile<Layout...>/Shape Tilers must be weakly_congruent to the Tensor
// Each divide below applies the corresponding layout-level divide to tensor.layout().
// The result is a tensor over the same data iterator with the divided layout.

// logical_divide: see the Tiler description above.
template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(Tensor     && tensor,
               Tiler const& tiler)   // Layout or Tile<Layout...> or Shape
{
  auto divided_layout = logical_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), divided_layout);
}

// zipped_divide: logical_divide with the Tiler modes and the Rest modes each gathered
// into a single mode: (Tiler,Rest).
//   * For a Layout Tiler, this has no effect: logical_divide already gives that form.
//   * For a Tile<Layout...> or Shape Tiler, the modes are zipped into the standard form
//     ((BLK_A,BLK_B),(a,b,x,y)).
template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(Tensor     && tensor,
              Tiler const& tiler)    // Layout or Tile<Layout...> or Shape
{
  auto zipped_layout = zipped_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), zipped_layout);
}

// tiled_divide: zipped_divide with the second (Rest) mode flattened: ((BLK_A,BLK_B),a,b,x,y).
template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
tiled_divide(Tensor     && tensor,
             Tiler const& tiler)     // Layout or Tile<Layout...> or Shape
{
  auto tiled_layout = tiled_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), tiled_layout);
}

// flat_divide: zipped_divide with both modes flattened: (BLK_A,BLK_B,a,b,x,y).
template <class Tensor, class Tiler,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
flat_divide(Tensor     && tensor,
            Tiler const& tiler)      // Layout or Tile<Layout...> or Shape
{
  auto flat_layout = flat_divide(tensor.layout(), tiler);
  return make_tensor(static_cast<Tensor&&>(tensor).data(), flat_layout);
}
// logical_product on a Tensor doesn't make sense since it often increases cosize
// though this might make sense for creating Tensors with broadcasted (stride-0) modes
//
// Tensor partitioning utilities
//
// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes.
// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to.
// Split the modes of tensor according to the Tiler
// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
// Then slice into the second mode (the "Rest" mode) with Coord
// inner_partition: zipped_divide the tensor by tiler into (Tile,Rest), then slice the
// Rest mode with coord, keeping the whole Tile mode.
//   * tuple coord: padded with trailing '_' (append<R1>) up to the Rest mode's rank, so any
//     Rest modes not covered by coord are kept.
//   * non-tuple coord: used directly as the (flat) index into the Rest mode.
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
inner_partition(Tensor     && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tensor_tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int R0 = decltype(rank<0>(tensor_tiled))::value;

  // The coord slices into the second mode (the "rest" mode), flatten the first
  if constexpr (is_tuple<Coord>::value) {
    // Append trailing modes if coord is tuple
    constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;
    return tensor_tiled(repeat<R0>(_), append<R1>(coord,_));
  } else {
    // Flat indexing if coord is not tuple
    return tensor_tiled(repeat<R0>(_), coord);
  }

  // Matches the other if-constexpr-returning functions in this file.
  CUTE_GCC_UNREACHABLE;
}
// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes.
// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor.
// Split the modes of tensor according to the Tiler
// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y))
// Then slice into the first mode (the "Tile" mode) with Coord
// outer_partition: zipped_divide the tensor by tiler into (Tile,Rest), then slice the
// Tile mode with coord, keeping the whole Rest mode.
//   * non-tuple coord: used directly as the (flat) index into the Tile mode.
//   * tuple coord: padded with trailing '_' (append<R0>) up to the Tile mode's rank.
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
outer_partition(Tensor     && tensor,
                Tiler const& tiler,
                Coord const& coord)
{
  auto tiled = zipped_divide(static_cast<Tensor&&>(tensor), tiler);
  constexpr int rest_rank = decltype(rank<1>(tiled))::value;
  auto const keep_rest = repeat<rest_rank>(_);

  if constexpr (not is_tuple<Coord>::value) {
    return tiled(coord, keep_rest);
  } else {
    constexpr int tile_rank = decltype(rank<0>(tiled))::value;
    return tiled(append<tile_rank>(coord,_), keep_rest);
  }
}
// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile.
// This is typical at the CTA level where tiles of data are extracted:
//   Tensor data = ...                                                                     // ( M, N)
//   Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64)
// Implemented directly as inner_partition(tensor, tiler, coord).
template <class Tensor, class Tiler, class Coord,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE constexpr
auto
local_tile(Tensor     && tensor,
           Tiler const& tiler,    // tiler to apply
           Coord const& coord)    // coord to slice into "remainder"
{
  return inner_partition(static_cast<Tensor&&>(tensor), tiler, coord);
}
// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience
// when using projections of the same tiler.
// This is typical at the CTA level where tiles of data are extracted as projections:
// Tensor dataA = ... // (M,K)
// Tensor dataB = ... // (N,K)
// Tensor dataC = ... // (M,N)
// auto cta_tiler = Shape<_32, _64, _4>{};
// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);
// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k)
//   Tensor ctaB = local_tile(dataB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (_64,_4,k)
//   Tensor ctaC = local_tile(dataC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (_32,_64)
// Projected local_tile: dice both tiler and coord with proj, then apply the
// three-argument local_tile.
template <class Tensor, class Tiler, class Coord, class Proj,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_tile(Tensor     && tensor,
           Tiler const& tiler,    // tiler to apply
           Coord const& coord,    // coord to slice into "remainder"
           Proj  const& proj)     // projection to apply to tiler and coord
{
  auto const diced_tiler = dice(proj, tiler);
  auto const diced_coord = dice(proj, coord);
  return local_tile(static_cast<Tensor&&>(tensor), diced_tiler, diced_coord);
}
// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index.
// This is typical at the Thread level where data is partitioned across repeated patterns of threads:
//   Tensor data = ...                                                            // (_16,_64)
//   Tensor thr_data = local_partition(data, Layout<Shape<_2,_16>>{}, thr_idx);  // ( _8, _4)
// The tiler is product_each(shape(tile)); the slice coordinate is tile.get_flat_coord(index),
// applied with outer_partition.
template <class Tensor, class LShape, class LStride, class Index,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index)  // index to slice for
{
  static_assert(is_integral<Index>::value);
  auto const tile_tiler = product_each(shape(tile));
  auto const tile_coord = tile.get_flat_coord(index);
  return outer_partition(static_cast<Tensor&&>(tensor), tile_tiler, tile_coord);
}
// Same as above, but the projection @a proj is first applied to the thread layout @a tile (via dice),
// removing the modes marked X. One thread layout can then be shared by several operands that each
// use only a subset of its modes:
//   Tensor dataA = ...                                                             // (M,K)
//   Tensor dataB = ...                                                             // (N,K)
//   Tensor dataC = ...                                                             // (M,N)
//   auto thr_layout = Layout<Shape<_2,_16,_1>, Stride<_16,_1,_0>>{};
//   Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{});  // (M/2,K/1)
//   Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{});  // (N/16,K/1)
//   Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{});  // (M/2,N/16)
template <class Tensor, class LShape, class LStride, class Index, class Projection,
          __CUTE_REQUIRES(is_tensor<remove_cvref_t<Tensor>>::value)>
CUTE_HOST_DEVICE
auto
local_partition(Tensor                     && tensor,
                Layout<LShape,LStride> const& tile,   // coord -> index
                Index                  const& index,  // index to slice for
                Projection             const& proj)   // projection applied to tile
{
  // Thread layout restricted to the modes kept by the projection.
  auto const projected_tile = dice(proj, tile);
  return local_partition(static_cast<Tensor&&>(tensor), projected_tile, index);
}
//
// Display utilities
//
// Print a compact description of @a tensor: its data handle, the separator " o ", then its layout.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print(Tensor<Engine,Layout> const& tensor)
{
  print(tensor.data());
  print(" o ");
  print(tensor.layout());
}
// Print the values of @a tensor as a grid, each element via pretty_print.
// When @a print_type is true, the description from print(tensor) and ":\n" are emitted first.
//   rank 1: one element per line
//   rank 2: one row (mode 0 index) per line
//   rank 3: the 2-D slices along mode 2, separated by lines of '-'
//   rank 4: the 3-D slices along mode 3, separated by lines of '='
// For any other rank only the optional description is printed.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void print_tensor(Tensor<Engine,Layout> const& tensor, bool print_type = true)
{
  if (print_type) {
    print(tensor);
    print(":\n");
  }

  if constexpr (Layout::rank == 1) {
    for (int idx = 0; idx < size(tensor); ++idx) {
      pretty_print(tensor(idx));
      printf("\n");
    }
  } else if constexpr (Layout::rank == 2) {
    for (int row = 0; row < size<0>(tensor); ++row) {
      for (int col = 0; col < size<1>(tensor); ++col) {
        pretty_print(tensor(row,col));
      }
      printf("\n");
    }
  } else if constexpr (Layout::rank == 3) {
    // First slice has no leading separator; later slices are each preceded by one.
    print_tensor(tensor(_,_,0), false);
    for (int plane = 1; plane < size<2>(tensor); ++plane) {
      // 5 characters per column; presumably matches pretty_print's field width -- TODO confirm
      for (int dash = 0; dash < 5*size<1>(tensor); ++dash) {
        print("-");
      }
      print("\n");
      print_tensor(tensor(_,_,plane), false);
    }
  } else if constexpr (Layout::rank == 4) {
    print_tensor(tensor(_,_,_,0), false);
    for (int cube = 1; cube < size<3>(tensor); ++cube) {
      for (int dash = 0; dash < 5*size<1>(tensor); ++dash) {
        print("=");
      }
      print("\n");
      print_tensor(tensor(_,_,_,cube), false);
    }
  }
}
#if !defined(__CUDACC_RTC__)
// Write the values of @a tensor to @a os, right-aligning each element in a fixed-width field.
//   rank 1: one element per line
//   rank 2: one row (mode 0 index) per line
//   rank 3: the 2-D slices along mode 2, separated by lines of '-'
//   rank 4: the 3-D slices along mode 3, separated by lines of '='
// For any other rank nothing is written. Returns @a os.
template <class Engine, class Layout>
CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  int const width = 9;  // field width used for every element and for separator length

  if constexpr (Layout::rank == 1) {
    for (int idx = 0; idx < size(tensor); ++idx) {
      os << std::setw(width) << tensor(idx) << std::endl;
    }
  } else if constexpr (Layout::rank == 2) {
    for (int row = 0; row < size<0>(tensor); ++row) {
      for (int col = 0; col < size<1>(tensor); ++col) {
        os << std::setw(width) << tensor(row,col);
      }
      os << std::endl;
    }
  } else if constexpr (Layout::rank == 3) {
    // First slice has no leading separator; later slices are each preceded by one.
    print_tensor_os(os, tensor(_,_,0));
    for (int plane = 1; plane < size<2>(tensor); ++plane) {
      for (int dash = 0; dash < width*size<1>(tensor); ++dash) {
        os << "-";
      }
      os << std::endl;
      print_tensor_os(os, tensor(_,_,plane));
    }
  } else if constexpr (Layout::rank == 4) {
    print_tensor_os(os, tensor(_,_,_,0));
    for (int cube = 1; cube < size<3>(tensor); ++cube) {
      for (int dash = 0; dash < width*size<1>(tensor); ++dash) {
        os << "=";
      }
      os << std::endl;
      print_tensor_os(os, tensor(_,_,_,cube));
    }
  }

  return os;
}
// Stream @a tensor: its layout on the first line, followed by its values as formatted by
// print_tensor_os. Returns @a os.
template <class Engine, class Layout>
CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor<Engine,Layout> const& tensor)
{
  return print_tensor_os(os << tensor.layout() << std::endl, tensor);
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute
//
// Extended Engines
//
#include <cute/pointer_swizzle.hpp>
#include <cute/pointer_flagged.hpp>
//
// Tensor Algorithms
//
#include <cute/algorithm/tensor_algorithms.hpp>
#include <cute/algorithm/fill.hpp>
#include <cute/algorithm/clear.hpp>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/prefetch.hpp>
#include <cute/algorithm/axpby.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/algorithm/cooperative_copy.hpp>
#include <cute/algorithm/cooperative_gemm.hpp>
|