Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
/*! \file | |
    \brief Unit tests for epilogue thread-level activation functors (GELU_taylor)
*/ | |
#include "../../common/cutlass_unit_test.h" | |
#include "cutlass/layout/layout.h" | |
#include "cutlass/epilogue/thread/activation.h" | |
#include "cutlass/util/host_tensor.h" | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Test kernel: each thread applies the functor `Func` to one `cutlass::Array<T, N>` fragment.
///
/// Thread `threadIdx.x` reads fragment `threadIdx.x` of `in`, evaluates `Func` on it, and stores
/// the result to fragment `threadIdx.x` of `out`. Only `threadIdx.x` is used for indexing and
/// there is no bounds check, so the caller is expected to launch exactly one thread per fragment.
///
/// \tparam T     element type of the buffers
/// \tparam N     number of elements per fragment
/// \tparam Func  default-constructible functor callable on a `cutlass::Array<T, N>`
///
/// \param out  destination buffer, reinterpreted as an array of `cutlass::Array<T, N>`
/// \param in   source buffer, reinterpreted as an array of `cutlass::Array<T, N>`
///
/// NOTE(review): both pointers presumably need to satisfy the alignment of
/// `cutlass::Array<T, N>` -- confirm against the allocation used by the caller.
template <typename T, int N, typename Func>
__global__ void test_Epilogue_thread_activation(T *out, T *in) {

  using Fragment = cutlass::Array<T, N>;

  // View the scalar buffers as arrays of N-wide fragments.
  Fragment *fragments_out = reinterpret_cast<Fragment *>(out);
  Fragment *fragments_in = reinterpret_cast<Fragment *>(in);

  // One fragment per thread.
  unsigned const fragment_idx = threadIdx.x;

  Func activation;
  fragments_out[fragment_idx] = activation(fragments_in[fragment_idx]);
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |
// | |
// Reference | |
// | |
/// Input samples fed to the activation under test.
///
/// Holds 256 values laid out four per row; entry `i` pairs with entry `i` of
/// `GELU_golden_output`. The index-range markers below make it easy to locate the entries that
/// receive per-index tolerance overrides in the tests.
static double GELU_golden_input[] = {
  // indices [0, 32)
  1.587425827980, 1.157652974129, 0.750432848930, -0.965980410576,
  -0.388184845448, 0.014422321692, 0.353164494038, 1.354383468628,
  0.167588576674, 0.272798538208, -0.377032428980, 1.923444747925,
  0.308164477348, -0.341318070889, 0.278338819742, -0.292668998241,
  -1.051743745804, -0.814175724983, 0.112737402320, 1.262938618660,
  -1.582363605499, 0.722016870975, 1.053453564644, -0.659764587879,
  0.734917521477, 0.091274201870, 0.604461073875, -0.219043627381,
  -0.136795744300, 0.960650205612, -1.805408835411, 0.091029644012,
  // indices [32, 64)
  -1.023343324661, 0.147713735700, -0.499895423651, 1.351878166199,
  -1.631091356277, -0.336171895266, -1.612408638000, 0.090832948685,
  -0.658132910728, -0.326727777719, -1.986387014389, 0.787685871124,
  -1.015677452087, -0.225094825029, 0.876752018929, 0.744826257229,
  0.870290279388, -0.757595360279, 1.510331749916, 0.750012576580,
  0.906444966793, -0.915759027004, 1.260277032852, -0.158465340734,
  -0.109191477299, -0.817102134228, 0.391305118799, -0.524910449982,
  0.351349592209, 0.801979541779, 0.446691334248, -0.741077482700,
  // indices [64, 96)
  1.205966711044, -0.910210072994, 0.945986449718, 0.784096539021,
  1.670521497726, 0.344931513071, -0.301411420107, 0.309870749712,
  -0.879704594612, -1.951189517975, -0.805817663670, -0.661812782288,
  -0.505914270878, -1.836273789406, -0.381845980883, -0.554707705975,
  -0.375447630882, -0.516645610332, 0.509586095810, 1.087131023407,
  2.664817094803, -1.558295488358, -0.076461032033, -0.504621028900,
  1.327111959457, -1.819981694221, 1.350415468216, -2.074112653732,
  1.501431345940, -1.339013576508, 0.162817999721, -1.473457217216,
  // indices [96, 128)
  0.357770472765, 0.188413277268, 1.601302266121, -0.653882205486,
  0.856162548065, 0.763102591038, -0.526283502579, 0.581961452961,
  0.089969776571, 1.968745589256, 0.545802056789, -1.168786048889,
  1.206663012505, -0.109096683562, -1.223938226700, 0.744599223137,
  -1.779406785965, 0.766436159611, -0.579044401646, -1.002057313919,
  -0.715845823288, -0.562508940697, 0.886768460274, 2.327786445618,
  -0.148763969541, -0.918884515762, -0.367678701878, -1.105021238327,
  -0.461237311363, 0.158228352666, -0.254040330648, 1.427477598190,
  // indices [128, 160)
  0.277530491352, 0.046293262392, -0.535557329655, -1.486695051193,
  -0.953706681728, -1.040495038033, -0.314667612314, 0.348172843456,
  0.522773325443, 0.025960063562, -0.482472360134, 1.993084549904,
  -0.253064930439, -0.012146313675, -2.166327714920, 0.398040622473,
  -0.022238900885, -0.443580865860, -0.898376941681, -0.571689844131,
  1.666979670525, -0.831176340580, -0.671057403088, 0.481970995665,
  -1.096243023872, -1.493894338608, 0.596651911736, -0.229505166411,
  1.165976166725, 0.905094027519, 0.049716457725, -1.362933635712,
  // indices [160, 192)
  -0.366948783398, 1.461613893509, -0.718411505222, 0.895385026932,
  -0.763122260571, 1.329716682434, 1.366570711136, -0.086544901133,
  0.059739742428, 0.940766513348, -0.272854357958, -1.738811373711,
  -0.361239165068, 0.696977972984, 1.288442254066, 1.264815807343,
  -0.573566436768, -1.141678214073, 0.081865988672, -0.886228799820,
  -0.236933603883, 1.050115466118, -0.538952171803, 0.651773929596,
  -0.220034509897, -1.198960781097, 1.247478365898, -0.053529661149,
  0.639809548855, 1.672434806824, 0.511088073254, -1.179364681244,
  // indices [192, 224)
  -0.730427742004, 0.157630980015, 0.389369845390, -0.925578773022,
  -0.093250080943, -0.391062080860, 0.852983593941, 1.868778109550,
  -1.198786258698, 0.604997038841, -1.482687234879, -2.469333171844,
  0.718807697296, -0.559609353542, 2.187228441238, -2.927527904510,
  0.148535788059, -0.097280368209, 0.674131810665, -1.137645959854,
  0.792729616165, -1.166317462921, -0.498791724443, 1.675866723061,
  -0.137909621000, -0.653263568878, -2.281216144562, 0.296096831560,
  2.002410173416, 1.083609819412, 0.933580815792, -1.504760265350,
  // indices [224, 256)
  2.185185909271, 0.286121010780, -1.035485863686, -0.216372340918,
  -0.274334043264, -0.849510788918, -1.397169828415, -0.407644748688,
  0.159476816654, -0.170650705695, 0.335193097591, -0.156852483749,
  0.036168430001, 0.858105242252, -1.086121797562, 0.404813349247,
  -0.481496721506, -0.389882832766, 0.020690204576, -0.772020936012,
  -0.758921504021, 0.323482036591, 0.115715265274, -0.811228036880,
  -0.882436633110, 0.176811277866, 1.678015947342, 0.379081040621,
  -0.842976212502, 0.346952259541, -0.545828759670, 1.632800459862
};
/// Expected results for the samples in `GELU_golden_input`.
///
/// Holds 256 values laid out four per row; entry `i` is the value the tests expect for input
/// entry `i`. NOTE(review): how these reference values were generated is not visible in this
/// file -- presumably an exact (non-approximated) GELU evaluation; confirm before regenerating.
static double GELU_golden_output[] = {
  // indices [0, 32)
  1.498199582100, 1.014679551125, 0.580462038517, -0.161344811320,
  -0.135453075171, 0.007294139825, 0.225325092673, 1.235459089279,
  0.094946734607, 0.165724009275, -0.133120641112, 1.871103763580,
  0.191376730800, -0.125069886446, 0.169681981206, -0.112644664943,
  -0.154036879539, -0.169163048267, 0.061428427696, 1.132469892502,
  -0.089851818979, 0.552240371704, 0.899579226971, -0.168043658137,
  0.565008401871, 0.048956073821, 0.439583092928, -0.090532489121,
  -0.060955654830, 0.798911273479, -0.064101703465, 0.048816055059,
  // indices [32, 64)
  -0.156645998359, 0.082529976964, -0.154254898429, 1.232632875443,
  -0.083896033466, -0.123835846782, -0.086161509156, 0.048703473061,
  -0.167972877622, -0.121522113681, -0.046670529991, 0.617986679077,
  -0.157319813967, -0.092503339052, 0.709896743298, 0.574865520000,
  0.703132867813, -0.169963955879, 1.411436080933, 0.580042064190,
  0.741154611111, -0.164741978049, 1.129479527473, -0.069256491959,
  -0.049848672003, -0.169087052345, 0.255214750767, -0.157380074263,
  0.223928079009, 0.632535398006, 0.300378054380, -0.169946283102,
  // indices [64, 96)
  1.068588852882, -0.165071934462, 0.783203184605, 0.614346146584,
  1.591325283051, 0.219006344676, -0.115003645420, 0.192637458444,
  -0.166712537408, -0.049788996577, -0.169361919165, -0.168130636215,
  -0.155041679740, -0.060888241976, -0.134137839079, -0.160614117980,
  -0.132782235742, -0.156389534473, 0.354075312614, 0.936574816704,
  2.654553413391, -0.092845752835, -0.035900454968, -0.154874503613,
  1.204704761505, -0.062572605908, 1.230982899666, -0.039479542524,
  1.401402950287, -0.120890334249, 0.091938301921, -0.103604510427,
  // indices [96, 128)
  0.228880971670, 0.108285568655, 1.513783097267, -0.167782157660,
  0.688394129276, 0.593158841133, -0.157540664077, 0.418839782476,
  0.048209801316, 1.920528769493, 0.386099845171, -0.141709372401,
  1.069367766380, -0.049809500575, -0.135230198503, 0.574639260769,
  -0.066881760955, 0.596510827541, -0.162873372436, -0.158483341336,
  -0.169686436653, -0.161375194788, 0.720409095287, 2.304597616196,
  -0.065585561097, -0.164551988244, -0.131098195910, -0.148708447814,
  -0.148663327098, 0.089060656726, -0.101548098028, 1.317959904671,
  // indices [128, 160)
  0.169103100896, 0.024001283571, -0.158595800400, -0.101909510791,
  -0.162240833044, -0.155090972781, -0.118474565446, 0.221488356590,
  0.365645468235, 0.013248858973, -0.151851043105, 1.946992278099,
  -0.101253561676, -0.006014300976, -0.032804865390, 0.260597169399,
  -0.010922161862, -0.145792976022, -0.165743649006, -0.162226170301,
  1.587365984917, -0.168676435947, -0.168497130275, 0.330191940069,
  -0.149622067809, -0.100989677012, 0.432351946831, -0.093922272325,
  1.023946166039, 0.739726305008, 0.025843897834, -0.117827951908,
  // indices [160, 192)
  -0.130937814713, 1.356489539146, -0.169726014137, 0.729478538036,
  -0.169943705201, 1.207641005516, 1.249209761620, -0.040288090706,
  0.031292784959, 0.777626037598, -0.107090584934, -0.071350336075,
  -0.129670530558, 0.527676224709, 1.161149263382, 1.134579420090,
  -0.162394225597, -0.144757837057, 0.043603736907, -0.166386902332,
  -0.096278958023, 0.895924389362, -0.158969298005, 0.484089732170,
  -0.090857118368, -0.138206124306, 1.115107178688, -0.025622237474,
  0.472724437714, 1.593463659286, 0.355387806892, -0.140493586659,
  // indices [192, 224)
  -0.169871479273, 0.088687323034, 0.253673940897, -0.164135158062,
  -0.043161027133, -0.136040985584, 0.685087263584, 1.811169505119,
  -0.138226687908, 0.440080583096, -0.102422207594, -0.016713079065,
  0.549075841904, -0.161096408963, 2.155813455582, -0.005001218989,
  0.083037458360, -0.044870752841, 0.505522191525, -0.145202502608,
  0.623111069202, -0.141991063952, -0.154108211398, 1.597298502922,
  -0.061391282827, -0.167753636837, -0.025704355910, 0.182520583272,
  1.957115054131, 0.932696640491, 0.769961357117, -0.099604383111,
  // indices [224, 256)
  2.153636932373, 0.175279796124, -0.155551761389, -0.089653611183,
  -0.107515335083, -0.168032020330, -0.113423995674, -0.139319628477,
  0.089841812849, -0.073763631284, 0.211594089866, -0.068651281297,
  0.018605981022, 0.690416753292, -0.150658726692, 0.266040354967,
  -0.151710823178, -0.135800719261, 0.010515870526, -0.169883996248,
  -0.169960290194, 0.202769815922, 0.063187584281, -0.169236257672,
  -0.166577890515, 0.100812792778, 1.599699616432, 0.245525524020,
  -0.168275654316, 0.220552831888, -0.159705042839, 1.549110531807
};
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Evaluates `GELU_taylor` on the device for `float` fragments and compares every element
/// against the golden reference table using a relative-error tolerance.
TEST(Epilogue_thread_gelu_taylor, device_f32) {

  int const kN = 256;   // number of scalar elements under test
  int const kV = 4;     // elements per fragment (one fragment per thread)

  // The kernel has no bounds check and the loops below index the golden tables directly,
  // so pin down the size assumptions at compile time.
  static_assert(kN % kV == 0, "kN must be a multiple of the fragment width kV");
  static_assert(int(sizeof(GELU_golden_input) / sizeof(GELU_golden_input[0])) >= kN,
    "GELU_golden_input has fewer than kN entries");
  static_assert(int(sizeof(GELU_golden_output) / sizeof(GELU_golden_output[0])) >= kN,
    "GELU_golden_output has fewer than kN entries");

  using Element = float;
  using Func = cutlass::epilogue::thread::GELU_taylor<cutlass::Array<Element, kV>>;

  // Default bound on |relative error|; selected indices are relaxed in the verify loop.
  double tolerance = 0.005;

  //
  // Construct workspace
  //
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

  for (int i = 0; i < kN; ++i) {
    tensor_Source.host_data(i) = Element(GELU_golden_input[i]);
  }

  tensor_Destination.sync_device();
  tensor_Source.sync_device();

  //
  // Launch the kernel
  //
  dim3 grid(1,1,1);
  dim3 block(kN / kV, 1, 1);

  test_Epilogue_thread_activation<Element, kV, Func><<< grid, block >>>(
    tensor_Destination.device_data(),
    tensor_Source.device_data());

  // Kernel launches do not report errors directly: check the launch itself, then wait for
  // completion so execution faults are reported here rather than as bogus mismatches below.
  cudaError_t result = cudaGetLastError();
  ASSERT_EQ(result, cudaSuccess) << "Kernel launch failed: " << cudaGetErrorString(result);

  result = cudaDeviceSynchronize();
  ASSERT_EQ(result, cudaSuccess) << "Kernel execution failed: " << cudaGetErrorString(result);

  tensor_Destination.sync_host();

  //
  // Verify
  //
  for (int i = 0; i < kN; ++i) {
    Element input = Element(GELU_golden_input[i]);
    Element got = tensor_Destination.host_data(i);
    Element expected = Element(GELU_golden_output[i]);

    double rel_error = (double(got) - double(expected)) / double(expected);

    // Relaxed bounds for the inputs below about -2.1 (e.g. index 207 is about -2.93): their
    // expected outputs are close to zero, which amplifies the relative error.
    double tolerance_override = tolerance;
    switch (i) {
      case 142: tolerance_override = 0.008; break;
      case 203: tolerance_override = 0.03; break;
      case 207: tolerance_override = 0.09; break;
      case 218: tolerance_override = 0.013; break;
      default: break;
    }

    EXPECT_LT(std::abs(rel_error), tolerance_override)
      << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
  }
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Evaluates `GELU_taylor` on the device for `cutlass::half_t` fragments and compares every
/// element against the golden reference table using a relative-error tolerance.
TEST(Epilogue_thread_gelu_taylor, device_f16) {

  int const kN = 256;   // number of scalar elements under test
  int const kV = 8;     // elements per fragment (one fragment per thread)

  // The kernel has no bounds check and the loops below index the golden tables directly,
  // so pin down the size assumptions at compile time.
  static_assert(kN % kV == 0, "kN must be a multiple of the fragment width kV");
  static_assert(int(sizeof(GELU_golden_input) / sizeof(GELU_golden_input[0])) >= kN,
    "GELU_golden_input has fewer than kN entries");
  static_assert(int(sizeof(GELU_golden_output) / sizeof(GELU_golden_output[0])) >= kN,
    "GELU_golden_output has fewer than kN entries");

  using Element = cutlass::half_t;
  using Func = cutlass::epilogue::thread::GELU_taylor<cutlass::Array<Element, kV>>;

  // Default bound on |relative error|; selected indices are relaxed in the verify loop.
  double tolerance = 0.005;

  //
  // Construct workspace
  //
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
  cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

  for (int i = 0; i < kN; ++i) {
    tensor_Source.host_data(i) = Element(GELU_golden_input[i]);
  }

  tensor_Destination.sync_device();
  tensor_Source.sync_device();

  //
  // Launch the kernel
  //
  dim3 grid(1,1,1);
  dim3 block(kN / kV, 1, 1);

  test_Epilogue_thread_activation<Element, kV, Func><<< grid, block >>>(
    tensor_Destination.device_data(),
    tensor_Source.device_data());

  // Kernel launches do not report errors directly: check the launch itself, then wait for
  // completion so execution faults are reported here rather than as bogus mismatches below.
  cudaError_t result = cudaGetLastError();
  ASSERT_EQ(result, cudaSuccess) << "Kernel launch failed: " << cudaGetErrorString(result);

  result = cudaDeviceSynchronize();
  ASSERT_EQ(result, cudaSuccess) << "Kernel execution failed: " << cudaGetErrorString(result);

  tensor_Destination.sync_host();

  //
  // Verify
  //
  for (int i = 0; i < kN; ++i) {
    Element input = Element(GELU_golden_input[i]);
    Element got = tensor_Destination.host_data(i);
    Element expected = Element(GELU_golden_output[i]);

    double rel_error = (double(got) - double(expected)) / double(expected);

    // Relaxed bounds for selected negative inputs (all below about -1.4; index 207 is about
    // -2.93): their expected outputs are small in magnitude, which amplifies the relative
    // error at half precision.
    double tolerance_override = tolerance;
    switch (i) {
      case 36: tolerance_override = 0.006; break;
      case 77: tolerance_override = 0.009; break;
      case 95: tolerance_override = 0.008; break;
      case 112: tolerance_override = 0.007; break;
      case 171: tolerance_override = 0.006; break;
      case 203: tolerance_override = 0.03; break;
      case 207: tolerance_override = 0.15; break;
      default: break;
    }

    EXPECT_LT(std::abs(rel_error), tolerance_override)
      << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
  }
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |