/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 * 
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

#include <stdlib.h>
#include <assert.h>

// Full 32-lane mask for the __shfl_*_sync warp shuffle intrinsics.
#define SHFL_MASK 0xffffffff
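
// Compile-time template parameters shared by the kernels declared below
// (meanings inferred from the parameter names and their use in the lightconv
// CUDA sources):
//   FS        - filter (convolution kernel) size
//   SB        - sequence block: number of sequence positions handled per thread block
//   padding_l - amount of left padding applied to the input sequence
//   scalar_t  - element type of the input/output tensors (e.g. float, at::Half)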

// Forward pass of the lightweight convolution: applies the filters to the
// input activations and writes the result to `output`. Each filter is shared
// by a block of numFiltersInBlock consecutive feature channels.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_forward_kernel(const scalar_t* input,
                              const scalar_t* filters,
                              int minibatch, int sequenceLength,
                              int numFeatures, int numFiltersInBlock,
                              scalar_t* output);

// Backward pass with respect to the convolution input: combines the incoming
// gradient (passed as `input`) with the filters and writes the input gradient
// to `output`.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// Backward pass with respect to the filter weights, computed in two passes:
// a first pass accumulates partial sums in float precision (note the float*
// output), and a second pass reduces those partial sums into the final
// scalar_t weight gradient. The *_short_* pair presumably covers short
// sequences; the later pair handles the general case.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output);

// Second pass of the short-sequence weight-gradient computation (see above).
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

// First pass of the weight-gradient computation for the general case.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output);

// Second pass of the weight-gradient computation for the general case.
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);
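
/* Illustrative only: a minimal sketch of how the forward kernel might be
 * launched from host code, assuming a filter size of 3, a sequence block of
 * 32, left padding of 1, float tensors laid out as
 * [minibatch, numFeatures, sequenceLength], and a (minibatch x numFeatures)
 * grid with SB threads per block. The launch shape, tensor layout, and the
 * lightconv_forward_example name are assumptions for illustration, not this
 * library's actual dispatch code, and the call only links in a translation
 * unit where the kernel template is instantiated.
 *
 *   at::Tensor lightconv_forward_example(at::Tensor input,
 *                                        at::Tensor filters,
 *                                        int numFiltersInBlock) {
 *     auto output = at::zeros_like(input);
 *     const int minibatch      = input.size(0);
 *     const int numFeatures    = input.size(1);
 *     const int sequenceLength = input.size(2);
 *
 *     dim3 blocks(minibatch, numFeatures);                     // assumed grid shape
 *     cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); // current ATen stream
 *
 *     lightconv_forward_kernel<3, 32, 1, float><<<blocks, 32, 0, stream>>>(
 *         input.data_ptr<float>(), filters.data_ptr<float>(),
 *         minibatch, sequenceLength, numFeatures, numFiltersInBlock,
 *         output.data_ptr<float>());
 *     return output;
 *   }
 */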