/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Forward declarations of the lightconv CUDA kernels. Only prototypes appear
// here; the kernel bodies are not part of this view, so buffer layouts and
// launch configuration cannot be read off this file.

// NOTE(review): the twelve include directives below carry no header names in
// this view (the <...> targets appear to have been lost). Restore them from
// the checked-in file; do not guess them.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// 32-bit mask with every bit set. Presumably the member mask handed to the
// warp shuffle intrinsics (all 32 lanes participate) -- confirm at use sites.
#define SHFL_MASK 0xffffffff

// NOTE(review): every `template` keyword below is missing its parameter list
// in this view. `scalar_t` is not declared anywhere visible and is presumably
// one of the template type parameters; there may be further (non-type)
// parameters -- confirm against the kernel definitions and their call sites.

// Forward-pass kernel (declaration only).
//   input, filters : read-only scalar_t buffers.
//   minibatch, sequenceLength, numFeatures, numFiltersInBlock :
//                    integer size arguments; their exact meaning and the
//                    buffer layouts are fixed by the kernel body (not shown).
//   output         : writable scalar_t buffer receiving the result.
template __global__ void lightconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// Kernel for the gradient with respect to the input (declaration only).
// Takes exactly the same parameter list as lightconv_forward_kernel:
// two read-only scalar_t buffers, four integer sizes, one writable
// scalar_t output buffer.
template __global__ void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// First pass of the weight-gradient computation, "short" variant
// (declaration only).
//   input, gradInput : read-only scalar_t buffers.
//   minibatch, sequenceLength, numFeatures, numFiltersInBlock, numHeads :
//                      integer size arguments; this is the only kernel in
//                      this header that takes numHeads.
//   output           : writable float buffer (float regardless of scalar_t).
// Presumably writes partial results that the matching second-pass kernel
// consumes -- confirm against the definitions.
template __global__ void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output);

// Second pass of the weight-gradient computation, "short" variant
// (declaration only).
//   input             : read-only float buffer (same element type as the
//                       first-pass output).
//   minibatch, numFiltersInBlock : integer size arguments.
//   output            : writable scalar_t buffer receiving the result.
template __global__ void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

// First pass of the weight-gradient computation, general variant
// (declaration only). Same parameters as the "short" first pass except
// that there is no numHeads argument; output is again a float buffer.
template __global__ void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output);

// Second pass of the weight-gradient computation, general variant
// (declaration only). Same parameter list as the "short" second pass:
// read-only float input, two integer sizes, writable scalar_t output.
template __global__ void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);