// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
                                            Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  // Smem carveout for TMA descriptor storage; only nonzero for array-of-pointers (grouped) GEMMs.
  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ?
      sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<
      detail::sm90_smem_capacity_bytes - KernelSmemCarveout, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  // Blockwise-scaling FP8 mainloop: accumulator is rescaled every ScaleGranularityM rows along M.
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
      PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
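// Usage sketch (illustrative comment only, not compiled): one plausible way to instantiate this
// builder for an FP8 GEMM with blockwise scaling. The element types, layouts, tile/cluster shapes,
// and ScaleGranularityM = 128 below are assumed example values, not requirements of this builder;
// the schedule tag is expected to come from the accompanying cutlass_extensions dispatch-policy header.
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//       cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,  // A: FP8 e4m3, K-major, 16-element (128b) TMA alignment
//       cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: FP8 e4m3, K-major, 16-element (128b) TMA alignment
//       float,                                                    // accumulate in fp32
//       cute::Shape<cute::_128, cute::_128, cute::_128>,          // CTA tile (M, N, K)
//       cute::Shape<cute::_1, cute::_2, cute::_1>,                // cluster shape
//       cutlass::gemm::collective::StageCountAuto,                // let the builder pick the stage count
//       cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<128>
//   >::CollectiveOp;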