File size: 10,609 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 |
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
#include "lightconv_cuda_forward.cu"
#include "lightconv_cuda_backward.cu"
#include "../cuda_utils.cu"
/**
 * Forward pass of the lightweight convolution.
 *
 * For one (batch, feature) row of length sequenceLength, each thread produces
 * one output position per tile as the dot product of an FS-tap filter with an
 * FS-wide window of the shared-memory staging buffer.
 *
 * Template parameters:
 *   FS        - number of filter taps
 *   SB        - threads per block; also the tile width along the sequence
 *   padding_l - forwarded to zeroSharedMem / load_input_to_shared
 *               (presumably the left padding of the convolution; confirm
 *               against the helper definitions)
 *
 * Launch layout as consumed here:
 *   blockIdx.x - batch index
 *   blockIdx.y - feature index
 *   blockDim.x - must equal SB (asserted)
 * Static shared memory: (SB + FS) elements of scalar_t per block.
 *
 * input and output are addressed as [batch][feature][position]; filters is
 * addressed as [featureIdx / numFiltersInBlock][FS]. minibatch is not read.
 */
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_forward_kernel(const scalar_t* input,
                              const scalar_t* filters,
                              int minibatch, int sequenceLength,
                              int numFeatures, int numFiltersInBlock,
                              scalar_t* output) {

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;

  // Build the row offset in size_t. The product of batch index, feature count
  // and sequence length does not fit in an int once the tensor holds more
  // than INT_MAX elements.
  const size_t IOOffset =
      (static_cast<size_t>(batchIdx) * numFeatures + featureIdx) * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  // Thread-local copy of the filter taps; every thread in the block reads
  // the same FS values.
  scalar_t filter[FS];
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[i];
  }

  __shared__ scalar_t temp[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(temp);
  // Separate the zero-fill from the first tile load, as the other kernels in
  // this file do. Both helpers receive the same shared buffer, so without a
  // barrier one thread's zero store could land after another thread's load.
  // NOTE(review): which slots each helper touches is defined elsewhere.
  __syncthreads();

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                            i, numIterations, (numIterations == 1), temp);

    // The tile must be fully staged before any thread reads its window.
    __syncthreads();

    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output; the last tile may extend past the end of the row.
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    // Nobody may start overwriting temp for the next tile while a neighbour
    // is still reading the current one.
    __syncthreads();
  }
}
/**
 * Gradient of the lightweight convolution with respect to its input.
 *
 * Structurally the same as the forward kernel, with two differences visible
 * in the code: the filter taps are loaded in reverse order, and the helpers
 * are instantiated with padding = FS - padding_l - 1 instead of padding_l.
 *
 * Launch layout as consumed here:
 *   blockIdx.x - batch index
 *   blockIdx.y - feature index
 *   blockDim.x - must equal SB (asserted)
 * Static shared memory: (SB + FS) elements of scalar_t per block.
 *
 * input and output are addressed as [batch][feature][position]; filters is
 * addressed as [featureIdx / numFiltersInBlock][FS]. minibatch is not read.
 * NOTE(review): `input` presumably carries the upstream gradient here;
 * confirm against the host-side launcher.
 */
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output) {

  // input grad kernel is similar to forward kernel
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;

  // Build the row offset in size_t. The product of batch index, feature count
  // and sequence length does not fit in an int once the tensor holds more
  // than INT_MAX elements.
  const size_t IOOffset =
      (static_cast<size_t>(batchIdx) * numFeatures + featureIdx) * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];
  const scalar_t* inputFilter = &filters[filterIdx * FS];

  assert(blockDim.x == SB);

  scalar_t filter[FS];

  // The only change is loading the filter in reverse
  #pragma unroll
  for (int i = 0; i < FS; ++i) {
    filter[i] = inputFilter[FS - i - 1];
  }

  __shared__ scalar_t temp[SB + FS];
  // Mirror of the forward padding, used for both helper instantiations below.
  const int padding = FS - padding_l - 1;
  zeroSharedMem<FS, SB, padding>(temp);

  // Zero-fill must be complete before the first tile is staged.
  __syncthreads();

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    // Read input into shared memory
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding>(inputFeature, inputOffset, sequenceLength,
                                          i, numIterations, false, temp);

    // The tile must be fully staged before any thread reads its window.
    __syncthreads();

    scalar_t out = 0;
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      out += filter[j] * temp[tid + j];
    }

    // Write output; the last tile may extend past the end of the row.
    const int outputOffset = inputOffset;
    if ((outputOffset + tid) < sequenceLength) {
      outputFeature[outputOffset + tid] = out;
    }

    // Keep temp intact until every thread has finished reading this tile.
    __syncthreads();
  }
}
// This is by far the most expensive kernel in terms of time taken.
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
//
// First pass of the weight gradient ("short" variant): one block handles one
// (batch, filter) pair and loops over all numFiltersInBlock feature rows that
// share that filter, so the per-feature partial sums never leave the block.
//
// Launch layout as consumed here:
//   blockIdx.x - batch index
//   blockIdx.y - filter index
//   blockDim.x - must equal SB (asserted)
// Static shared memory: 2 * (SB + FS) elements of scalar_t per block.
//
// input and gradInput are addressed as [batch][feature][position], with
// feature = filterIdx * numFiltersInBlock + idxInFilterBlock.
// output is float scratch addressed as [filterIdx][filterWeightIdx][batchIdx].
// numHeads is accepted but not read by this kernel.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output) {

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int filterIdx = blockIdx.y;

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  float* tempOutputGradWeight = &output[filterIdx * FS * minibatch];

  assert(blockDim.x == SB);

  __shared__ scalar_t tempInput[SB + FS];
  __shared__ scalar_t tempGradInput[SB + FS];

  // local weight accumulation
  float accumWeights[FS];

  // Initialize memory
  for (int i = 0; i < FS; ++i) {
    accumWeights[i] = float(0.0);
  }

  // loop over each sequence within filterblock
  for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) {

    // Build the row offset in size_t. The product of batch index, feature
    // count and sequence length does not fit in an int once the tensor holds
    // more than INT_MAX elements.
    const int featureIdx = filterIdx * numFiltersInBlock + idxInFilterBlock;
    const size_t featureOffset =
        (static_cast<size_t>(batchIdx) * numFeatures + featureIdx) * sequenceLength;
    const scalar_t* inputFeature = &input[featureOffset];
    const scalar_t* gradInputFeature = &gradInput[featureOffset];

    // Both staging buffers are reset for every feature row. The gradient
    // buffer is always instantiated with FS/2 rather than padding_l.
    zeroSharedMem<FS, SB, padding_l>(tempInput);
    zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
    __syncthreads();

    for (int i = 0; i < numIterations; ++i) {

      const int inputOffset = i * SB;

      load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                              i, numIterations, false, tempInput);
      load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
                                           i, numIterations, false, tempGradInput);

      // Both tiles must be fully staged before they are read.
      __syncthreads();

      // Each thread pairs one gradient element with the FS input elements
      // starting at its own position.
      const int gradIndex = (FS/2) + tid;
      scalar_t tempGrad = tempGradInput[gradIndex];

      // NOTE(review): the product is formed in scalar_t and only then
      // widened into the float accumulator.
      #pragma unroll
      for (int j = 0; j < FS; j++) {
        const int inputIndex = tid + j;
        accumWeights[j] += tempInput[inputIndex] * tempGrad;
      }

      // Keep the staging buffers intact until every thread is done with them.
      __syncthreads();
    }
  }

  // Row-major sum
  for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {

    // Threads past the end of the sequence contribute zero to the reduction.
    float temp;
    if (tid < sequenceLength) {
      temp = accumWeights[filterWeightIdx];
    } else {
      temp = float(0.0);
    }

    const int outputOffset = filterWeightIdx * minibatch + batchIdx;

    // blockReduce is defined elsewhere; the code relies on thread 0 holding
    // the block-wide result afterwards. Every thread must call it.
    temp = blockReduce(temp);

    if (tid == 0) {
      tempOutputGradWeight[outputOffset] = temp;
    }
  }
}
/**
 * Second pass of the weight gradient ("short" variant): collapses the batch
 * axis of the float scratch buffer written by the first pass.
 *
 * Launch layout as consumed here:
 *   blockIdx.x - filter index
 *   blockIdx.y - tap index within the filter (0 .. FS-1)
 *   blockDim.x - must equal SB (asserted)
 *
 * input is read as [filter][tap][minibatch]; the minibatch values of one
 * (filter, tap) pair are summed and stored at output[filter * FS + tap].
 * numFiltersInBlock is accepted but not read by this kernel.
 */
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output) {

  assert(blockDim.x == SB);

  const int lane = threadIdx.x;
  const int whichFilter = blockIdx.x;
  const int whichTap = blockIdx.y;

  // Start of the minibatch-long run of partial sums for this (filter, tap).
  const float* partials =
      input + (whichFilter * FS * minibatch + whichTap * minibatch);

  // Each thread strides over the run in steps of SB and keeps a private sum.
  float threadSum = 0.0f;
  for (int b = lane; b < minibatch; b += SB) {
    threadSum += partials[b];
  }

  // blockReduce is defined elsewhere; thread 0 is the one that publishes
  // the result. Every thread must call it.
  const float total = blockReduce(threadSum);

  if (lane == 0) {
    output[blockIdx.x * FS + blockIdx.y] = total;
  }
}
// This is by far the most expensive kernel in terms of time taken.
// Can be 16x slower than the forward or grad_wrt_input when filter size is 31
//
// First pass of the weight gradient (general variant): one block handles one
// (batch, feature) row and writes FS partial sums, one per filter tap, into a
// float scratch buffer that a second kernel reduces.
//
// Launch layout as consumed here:
//   blockIdx.x - batch index
//   blockIdx.y - feature index
//   blockDim.x - must equal SB (asserted)
// Static shared memory: 2 * (SB + FS) elements of scalar_t per block.
//
// input and gradInput are addressed as [batch][feature][position].
// output is addressed as [filterIdx][filterWeightIdx][batchIdx][idxInFilterBlock],
// with filterIdx = featureIdx / numFiltersInBlock and
// idxInFilterBlock = featureIdx % numFiltersInBlock.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output) {

  assert(blockDim.x == SB);

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int filterIdx = featureIdx / numFiltersInBlock;
  const int idxInFilterBlock = featureIdx % numFiltersInBlock;

  const int numIterations = divUp<int, int>(sequenceLength, SB);

  __shared__ scalar_t tempInput[SB + FS];
  __shared__ scalar_t tempGradInput[SB + FS];
  // The gradient buffer is always instantiated with FS/2 rather than padding_l.
  zeroSharedMem<FS, SB, padding_l>(tempInput);
  zeroSharedMem<FS, SB, (FS/2)>(tempGradInput);
  // Zero-fill must be complete before the first tiles are staged.
  __syncthreads();

  // Thread-local float accumulators, one per filter tap.
  float accumWeights[FS];

  for (int i = 0; i < FS; ++i) {
    accumWeights[i] = float(0.0);
  }

  // Build the row offset in size_t. The product of batch index, feature count
  // and sequence length does not fit in an int once the tensor holds more
  // than INT_MAX elements.
  const size_t IOOffset =
      (static_cast<size_t>(batchIdx) * numFeatures + featureIdx) * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  const scalar_t* gradInputFeature = &gradInput[IOOffset];
  float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock];

  for (int i = 0; i < numIterations; ++i) {
    const int inputOffset = i * SB;

    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset, sequenceLength,
                                            i, numIterations, false, tempInput);
    load_input_to_shared<FS, SB, (FS/2)>(gradInputFeature, inputOffset, sequenceLength,
                                         i, numIterations, false, tempGradInput);

    // Both tiles must be fully staged before they are read.
    __syncthreads();

    // Each thread pairs one gradient element with the FS input elements
    // starting at its own position.
    // NOTE(review): the product is formed in scalar_t and only then widened
    // into the float accumulator.
    #pragma unroll
    for (int j = 0; j < FS; ++j) {
      accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)];
    }

    // Keep the staging buffers intact until every thread is done with them.
    __syncthreads();
  }

  // Row-major sum
  for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) {

    // Threads past the end of the sequence contribute zero to the reduction.
    float temp = (tid < sequenceLength) ? accumWeights[filterWeightIdx]
                                        : float(0.0);

    // blockReduce is defined elsewhere; the code relies on thread 0 holding
    // the block-wide result afterwards. Every thread must call it.
    temp = blockReduce(temp);

    const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock +
                             batchIdx * numFiltersInBlock +
                             idxInFilterBlock;

    if (tid == 0) {
      tempOutputGradWeight[outputOffset] = temp;
    }
  }
}
/**
 * Second pass of the weight gradient (general variant): collapses the
 * combined (batch, feature-in-filter-block) axis of the float scratch buffer
 * written by the first pass.
 *
 * Launch layout as consumed here:
 *   blockIdx.x - filter index
 *   blockIdx.y - tap index within the filter (0 .. FS-1)
 *   blockDim.x - must equal SB (asserted)
 *
 * input is read as [filter][tap][minibatch * numFiltersInBlock]; the values
 * of one (filter, tap) pair are summed and stored at output[filter * FS + tap].
 */
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output) {

  assert(blockDim.x == SB);

  const int lane = threadIdx.x;
  const int whichFilter = blockIdx.x;
  const int whichTap = blockIdx.y;

  // Number of partial sums stored per (filter, tap) pair.
  const int runLength = minibatch * numFiltersInBlock;

  // Start of the run of partial sums for this (filter, tap).
  const float* partials = input + (whichFilter * FS + whichTap) * runLength;

  // Each thread strides over the run in steps of SB and keeps a private sum.
  float threadSum = float(0.0);
  for (int k = lane; k < runLength; k += SB) {
    threadSum += partials[k];
  }

  // blockReduce is defined elsewhere; thread 0 is the one that publishes
  // the result. Every thread must call it.
  const float total = blockReduce(threadSum);

  if (lane == 0) {
    output[blockIdx.x * FS + blockIdx.y] = total;
  }
}
|