Spaces:
Running
on
Zero
Running
on
Zero
/* | |
* Copyright (c) Meta Platforms, Inc. and affiliates. | |
* All rights reserved. | |
* | |
* This source code is licensed under the BSD-style license found in the | |
* LICENSE file in the root directory of this source tree. | |
*/ | |
// This file provides utilities for dispatching to specialized versions of | |
// functions. This is especially useful for CUDA kernels, since specializing | |
// them to particular input sizes can often allow the compiler to unroll loops | |
// and place arrays into registers, which can give huge performance speedups. | |
// | |
// As an example, suppose we have the following function which is specialized | |
// based on a compile-time int64_t value: | |
// | |
// template<typename T, int64_t x> | |
// struct SquareOffset { | |
// static void run(T y) { | |
// T val = x * x + y; | |
// std::cout << val << std::endl; | |
// } | |
// } | |
// | |
// This function takes one compile-time argument x, and one run-time argument y. | |
// We might want to compile specialized versions of this for x=0, x=1, etc and | |
// then dispatch to the correct one based on the runtime value of x. | |
// One simple way to achieve this is with a lookup table: | |
// | |
// template<typename T> | |
// void DispatchSquareOffset(const int64_t x, T y) { | |
// if (x == 0) { | |
// SquareOffset<T, 0>::run(y); | |
// } else if (x == 1) { | |
// SquareOffset<T, 1>::run(y); | |
// } else if (x == 2) { | |
// SquareOffset<T, 2>::run(y); | |
// } | |
// } | |
// | |
// This function takes both x and y as run-time arguments, and dispatches to | |
// different specialized versions of SquareOffset based on the run-time value | |
// of x. This works, but it's tedious and error-prone. If we want to change the | |
// set of x values for which we provide compile-time specializations, then we | |
// will need to do a lot of tedius editing of the dispatch function. Also, if we | |
// want to provide compile-time specializations for another function other than | |
// SquareOffset, we will need to duplicate the entire lookup table. | |
// | |
// To solve these problems, we can use the DispatchKernel1D function provided by | |
// this file instead: | |
// | |
// template<typename T> | |
// void DispatchSquareOffset(const int64_t x, T y) { | |
// constexpr int64_t xmin = 0; | |
// constexpr int64_t xmax = 2; | |
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y); | |
// } | |
// | |
// DispatchKernel1D uses template metaprogramming to compile specialized | |
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and | |
// then dispatches to the correct one based on the run-time value of x. If we | |
// want to change the range of x values for which SquareOffset is specialized | |
// at compile-time, then all we have to do is change the values of the | |
// compile-time constants xmin and xmax. | |
// | |
// This file also allows us to similarly dispatch functions that depend on two | |
// compile-time int64_t values, using the DispatchKernel2D function like this: | |
// | |
// template<typename T, int64_t x, int64_t y> | |
// struct Sum { | |
// static void run(T z, T w) { | |
// T val = x + y + z + w; | |
// std::cout << val << std::endl; | |
// } | |
// } | |
// | |
// template<typename T> | |
// void DispatchSum(const int64_t x, const int64_t y, int z, int w) { | |
// constexpr int64_t xmin = 1; | |
// constexpr int64_t xmax = 3; | |
// constexpr int64_t ymin = 2; | |
// constexpr int64_t ymax = 5; | |
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w); | |
// } | |
// | |
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to | |
// compile specialized versions of sum for all values of (x, y) with | |
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct | |
// specialized version based on the runtime values of x and y. | |
// Define some helper structs in an anonymous namespace. | |
namespace { | |
// 1D dispatch: general case. | |
// Kernel is the function we want to dispatch to; it should take a typename and | |
// an int64_t as template args, and it should define a static void function | |
// run which takes any number of arguments of any type. | |
// In order to dispatch, we will take an additional template argument curN, | |
// and increment it via template recursion until it is equal to the run-time | |
// argument N. | |
template < | |
template <typename, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t curN, | |
typename... Args> | |
struct DispatchKernelHelper1D { | |
static void run(const int64_t N, Args... args) { | |
if (curN == N) { | |
// The compile-time value curN is equal to the run-time value N, so we | |
// can dispatch to the run method of the Kernel. | |
Kernel<T, curN>::run(args...); | |
} else if (curN < N) { | |
// Increment curN via template recursion | |
DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run( | |
N, args...); | |
} | |
// We shouldn't get here -- throw an error? | |
} | |
}; | |
// 1D dispatch: Specialization when curN == maxN | |
// We need this base case to avoid infinite template recursion. | |
template < | |
template <typename, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
typename... Args> | |
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> { | |
static void run(const int64_t N, Args... args) { | |
if (N == maxN) { | |
Kernel<T, maxN>::run(args...); | |
} | |
// We shouldn't get here -- throw an error? | |
} | |
}; | |
// 2D dispatch, general case. | |
// This is similar to the 1D case: we take additional template args curN and | |
// curM, and increment them via template recursion until they are equal to | |
// the run-time values of N and M, at which point we dispatch to the run | |
// method of the kernel. | |
template < | |
template <typename, int64_t, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t curN, | |
int64_t minM, | |
int64_t maxM, | |
int64_t curM, | |
typename... Args> | |
struct DispatchKernelHelper2D { | |
static void run(const int64_t N, const int64_t M, Args... args) { | |
if (curN == N && curM == M) { | |
Kernel<T, curN, curM>::run(args...); | |
} else if (curN < N && curM < M) { | |
// Increment both curN and curM. This isn't strictly necessary; we could | |
// just increment one or the other at each step. But this helps to cut | |
// on the number of recursive calls we make. | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
curN + 1, | |
minM, | |
maxM, | |
curM + 1, | |
Args...>::run(N, M, args...); | |
} else if (curN < N) { | |
// Increment curN only | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
curN + 1, | |
minM, | |
maxM, | |
curM, | |
Args...>::run(N, M, args...); | |
} else if (curM < M) { | |
// Increment curM only | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
curN, | |
minM, | |
maxM, | |
curM + 1, | |
Args...>::run(N, M, args...); | |
} | |
} | |
}; | |
// 2D dispatch, specialization for curN == maxN | |
template < | |
template <typename, int64_t, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t minM, | |
int64_t maxM, | |
int64_t curM, | |
typename... Args> | |
struct DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
maxN, | |
minM, | |
maxM, | |
curM, | |
Args...> { | |
static void run(const int64_t N, const int64_t M, Args... args) { | |
if (maxN == N && curM == M) { | |
Kernel<T, maxN, curM>::run(args...); | |
} else if (curM < maxM) { | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
maxN, | |
minM, | |
maxM, | |
curM + 1, | |
Args...>::run(N, M, args...); | |
} | |
// We should not get here -- throw an error? | |
} | |
}; | |
// 2D dispatch, specialization for curM == maxM | |
template < | |
template <typename, int64_t, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t curN, | |
int64_t minM, | |
int64_t maxM, | |
typename... Args> | |
struct DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
curN, | |
minM, | |
maxM, | |
maxM, | |
Args...> { | |
static void run(const int64_t N, const int64_t M, Args... args) { | |
if (curN == N && maxM == M) { | |
Kernel<T, curN, maxM>::run(args...); | |
} else if (curN < maxN) { | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
curN + 1, | |
minM, | |
maxM, | |
maxM, | |
Args...>::run(N, M, args...); | |
} | |
// We should not get here -- throw an error? | |
} | |
}; | |
// 2D dispatch, specialization for curN == maxN, curM == maxM | |
template < | |
template <typename, int64_t, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t minM, | |
int64_t maxM, | |
typename... Args> | |
struct DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
maxN, | |
minM, | |
maxM, | |
maxM, | |
Args...> { | |
static void run(const int64_t N, const int64_t M, Args... args) { | |
if (maxN == N && maxM == M) { | |
Kernel<T, maxN, maxM>::run(args...); | |
} | |
// We should not get here -- throw an error? | |
} | |
}; | |
} // namespace | |
// This is the function we expect users to call to dispatch to 1D functions | |
template < | |
template <typename, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
typename... Args> | |
void DispatchKernel1D(const int64_t N, Args... args) { | |
if (minN <= N && N <= maxN) { | |
// Kick off the template recursion by calling the Helper with curN = minN | |
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run( | |
N, args...); | |
} | |
// Maybe throw an error if we tried to dispatch outside the allowed range? | |
} | |
// This is the function we expect users to call to dispatch to 2D functions | |
template < | |
template <typename, int64_t, int64_t> | |
class Kernel, | |
typename T, | |
int64_t minN, | |
int64_t maxN, | |
int64_t minM, | |
int64_t maxM, | |
typename... Args> | |
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { | |
if (minN <= N && N <= maxN && minM <= M && M <= maxM) { | |
// Kick off the template recursion by calling the Helper with curN = minN | |
// and curM = minM | |
DispatchKernelHelper2D< | |
Kernel, | |
T, | |
minN, | |
maxN, | |
minN, | |
minM, | |
maxM, | |
minM, | |
Args...>::run(N, M, args...); | |
} | |
// Maybe throw an error if we tried to dispatch outside the specified range? | |
} |