// clang-format off | |
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp | |
/*************************************************************************************************** | |
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
////////////////////////////////////////////////////////////////////////////// | |
///////////////////////////////////FP8 Accumulation/////////////////////////// | |
////////////////////////////////////////////////////////////////////////////// | |
/// This class provides API to promote (add) or scale (multiply_add) the results | |
/// from the tensor core accumulators to the main accumulators when the number | |
/// of MMAs reaches the max number of MMA interval specified by user, after that | |
/// the tensor core accumulators are zeroed. | |
////////////////////////////////////////////////////////////////////////////// | |
namespace cutlass::gemm::collective { | |
template < | |
class EngineAccum, | |
class LayoutAccum> | |
struct GmmaFP8AccumulationWithScale { | |
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>; | |
using ElementAccumulator = typename EngineAccum::value_type; | |
static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static"); | |
static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident."); | |
private: | |
TensorAccum& accum_; | |
TensorAccum accum_temp_; | |
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. | |
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop | |
uint32_t mma_count_; // current executed MMAs | |
uint32_t reset_accum_flag_; // accum needs to be zeroed or not. | |
// promote or `add` the partial accumulators to main accumulator (FADD). | |
CUTLASS_DEVICE | |
void promote_core() { | |
warpgroup_wait<0>(); | |
CUTLASS_PRAGMA_UNROLL | |
for (int i = 0; i < size(accum_); ++i) { | |
accum_(i) += accum_temp_(i); | |
} | |
} | |
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). | |
template < | |
class EngineScale, | |
class LayoutScale> | |
CUTLASS_DEVICE | |
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) { | |
using TensorScale = cute::Tensor<EngineScale, LayoutScale>; | |
static_assert(is_static<LayoutScale>::value, "Scale Layout should be static"); | |
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident."); | |
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); | |
warpgroup_wait<0>(); | |
CUTLASS_PRAGMA_UNROLL | |
for (int i = 0; i < size(accum_); ++i) { | |
accum_(i) += accum_temp_(i) * scale(i); | |
} | |
} | |
public: | |
CUTLASS_DEVICE | |
GmmaFP8AccumulationWithScale( | |
TensorAccum &accum, | |
uint32_t accum_promotion_interval, | |
uint32_t mma_count_per_mainloop_iteration) | |
: accum_(accum), | |
accum_promotion_interval_(accum_promotion_interval), | |
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), | |
mma_count_(0), | |
reset_accum_flag_(0) | |
{ | |
accum_temp_ = cute::make_fragment_like(accum); | |
} | |
// | |
// Methods (Common) | |
// | |
CUTLASS_DEVICE | |
TensorAccum& operator()() { | |
return accum_temp_; | |
} | |
/// prepare the MMA accumulators when initialization or zeroing is required. | |
CUTLASS_DEVICE | |
bool prepare_if_needed() { | |
return reset_accum_flag_; | |
} | |
// | |
// Methods (for FADD version) | |
// | |
/// promote (add) the results from the MMA accumulators to main accumulator if needed. | |
CUTLASS_DEVICE | |
void promote_if_needed() { | |
mma_count_ += mma_count_per_mainloop_iteration_; | |
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); | |
if (reset_accum_flag_) { | |
promote_core(); | |
mma_count_ = 0; | |
} | |
} | |
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed. | |
CUTLASS_DEVICE | |
void promote_residue_if_needed() { | |
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { | |
promote_core(); | |
} | |
} | |
// | |
// Methods (for FFMA version) | |
// | |
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. | |
template < | |
class EngineScale, | |
class LayoutScale> | |
CUTLASS_DEVICE | |
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) { | |
mma_count_ += mma_count_per_mainloop_iteration_; | |
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); | |
if (reset_accum_flag_) { | |
scale_core(scale); | |
mma_count_ = 0; | |
} | |
} | |
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. | |
template < | |
class EngineScale, | |
class LayoutScale> | |
CUTLASS_DEVICE | |
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) { | |
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { | |
scale_core(scale); | |
} | |
} | |
}; | |
} // namespace cutlass::gemm::collective | |