/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 * 
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "dynamicconv_cuda.cuh"
#include "dynamicconv_cuda_forward.cu"
#include "dynamicconv_cuda_backward.cu"
#include "../cuda_utils.cu"

// FS is the filter size; the kernels below are template-specialized for each supported filter size
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void dynamicconv_forward_kernel(const scalar_t* input,
                                const scalar_t* weight,
                                int minibatch,
                                int sequenceLength,
                                int numFeatures,
                                int numFiltersInBlock,
                                int numHeads,
                                scalar_t* output) {
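  // One thread block per (batch element, feature channel) pair: blockIdx.x
  // indexes the batch element and blockIdx.y the feature channel. The SB
  // threads in a block produce SB consecutive output time steps per chunk.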
  assert(blockDim.x == SB);

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int head = featureIdx / numFiltersInBlock;

  const int IOOffset = batchIdx * numFeatures * sequenceLength
                       + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];

  scalar_t filter[FS];

  __shared__ scalar_t tempInput[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  const int numIterations = divUp<int, int>(sequenceLength, SB);
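  // numIterations chunks of SB time steps cover the whole sequence; each chunk
  // is staged in shared memory together with the extra FS elements the filter
  // window needs at the chunk boundary.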

  for (int i = 0; i < numIterations; ++i) {
    __syncthreads();
    const int inputOffset = i * SB;
    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset,
                                            sequenceLength, i,
                                            numIterations, false, tempInput);
    __syncthreads();
    if (inputOffset + tid < sequenceLength) {

      // Gather the FS dynamic filter taps for this output position; the weight
      // tensor is laid out as [minibatch, numHeads, FS, sequenceLength].
      #pragma unroll
      for (int k = 0; k < FS; ++k) {
        const int filterOffset = batchIdx * numHeads * FS * sequenceLength
                                 + head * FS * sequenceLength
                                 + k * sequenceLength
                                 + i * SB + tid;
        filter[k] = weight[filterOffset];
      }

      // convolve: dot product of the filter taps with the staged input window
      scalar_t out = scalar_t(0.0);
      #pragma unroll
      for (int k = 0; k < FS; ++k) {
        out += filter[k] * tempInput[tid + k];
      }

      outputFeature[inputOffset + tid] = out;

    }
  }
}
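
// Illustrative only (not part of the original fairseq sources): a sketch of
// how one specialization of the forward kernel above could be launched. The
// real dispatch over filter sizes and block sizes is generated in
// dynamicconv_cuda_forward.cu; the function name and the template arguments
// chosen below are placeholder assumptions.
template<typename scalar_t>
static void dynamicconv_forward_launch_sketch(const scalar_t* input,
                                              const scalar_t* weight,
                                              int minibatch,
                                              int sequenceLength,
                                              int numFeatures,
                                              int numFiltersInBlock,
                                              int numHeads,
                                              scalar_t* output,
                                              cudaStream_t stream) {
  constexpr int FS = 3;         // one of the specialized filter sizes (assumed)
  constexpr int SB = 32;        // chunk size; must equal blockDim.x
  constexpr int padding_l = 2;  // left padding, chosen per filter size upstream (assumed)
  // One block per (batch, feature) pair, SB threads each, matching the
  // blockIdx/threadIdx usage inside the kernel.
  dim3 blocks(minibatch, numFeatures);
  dynamicconv_forward_kernel<FS, SB, padding_l, scalar_t>
      <<<blocks, SB, 0, stream>>>(input, weight, minibatch, sequenceLength,
                                  numFeatures, numFiltersInBlock, numHeads,
                                  output);
}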

template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void dynamicconv_backward_kernel(
    const scalar_t* gradOutput, // B * C * T
    const scalar_t* input, // B * C * T
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* gradWeight, // B * H * k * T
    scalar_t* gradInput) { // B * C * T

  assert(blockDim.x == SB);

  // each block operates on a single batch element, filter head, and sequence chunk
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int headIdx = blockIdx.y;
  const int chunkIdx = blockIdx.z;

  const int numChunks = divUp<int, int>(sequenceLength, SB);
  const int inputOffset = chunkIdx * SB;

  // initialize shared memory for output gradient and input
  __shared__ scalar_t tempGradOutput[SB + FS];
  __shared__ scalar_t tempInput[SB + FS];
  const int padding = FS - padding_l - 1;

  zeroSharedMem<FS, SB, padding>(tempGradOutput);
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  // initialize the per-thread weight-gradient accumulator and the reversed
  // filter taps (bfilter) used for the input gradient: tap k of bfilter is
  // tap FS - k - 1 of the dynamic filter, read at the output time step whose
  // gradient flows back to this thread's position
  scalar_t tempGradSum[FS];
  scalar_t bfilter[FS];
  for (int k = 0; k < FS; ++k) {
    tempGradSum[k] = scalar_t(0.0);

    int idxOffset = inputOffset + tid + k - padding;
    if (idxOffset >= 0 && idxOffset < sequenceLength) {
      int bfilterOffset = batchIdx * numHeads * FS * sequenceLength
                          + headIdx * FS * sequenceLength
                          + (FS - k - 1) * sequenceLength
                          + idxOffset;
      bfilter[k] = weight[bfilterOffset];
    } else {
      bfilter[k] = scalar_t(0.0);
    }
  }


  // iterate over filter block
  for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) {
    __syncthreads();

    // load input and output gradient for this channel and chunk
    const int IOOffset = batchIdx * numFeatures * sequenceLength
                         + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength;
    const scalar_t* inputFeature = &input[IOOffset];
    const scalar_t* gradOutputFeature = &gradOutput[IOOffset];
    scalar_t* gradInputFeature = &gradInput[IOOffset];

    load_input_to_shared<FS, SB, padding>(gradOutputFeature, inputOffset,
                                            sequenceLength, chunkIdx,
                                            numChunks, true, tempGradOutput);
    load_input_to_shared<FS, SB, padding_l>(inputFeature, inputOffset,
                                            sequenceLength, chunkIdx,
                                            numChunks, true, tempInput);
    __syncthreads();
 
    // accumulate the weight gradient (summed over the channels that share this
    // filter block) and compute the input gradient from the reversed filter
    scalar_t out = scalar_t(0.0);
    #pragma unroll
    for (int k = 0; k < FS; ++k) {
      tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding];
      out += bfilter[k] * tempGradOutput[tid + k];
    }
    
    if (inputOffset + tid < sequenceLength) {
      gradInputFeature[inputOffset + tid] = out;
    }
  }

  const int gradOffset = batchIdx * numHeads * FS * sequenceLength
               + headIdx * FS * sequenceLength;
  scalar_t *gradWeightFeature = &gradWeight[gradOffset];

  // write weight gradient
  if (inputOffset + tid < sequenceLength) {
    for (int k = 0; k < FS; ++k) {
      const int outputOffset = k * sequenceLength + inputOffset + tid;
      gradWeightFeature[outputOffset] = tempGradSum[k];
    }
  }
}
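
// Illustrative only (not part of the original fairseq sources): a matching
// launch sketch for the backward kernel. The real dispatch is generated in
// dynamicconv_cuda_backward.cu; the names and template arguments below are
// placeholder assumptions.
template<typename scalar_t>
static void dynamicconv_backward_launch_sketch(const scalar_t* gradOutput,
                                               const scalar_t* input,
                                               const scalar_t* weight,
                                               int minibatch,
                                               int sequenceLength,
                                               int numFeatures,
                                               int numFiltersInBlock,
                                               int numHeads,
                                               scalar_t* gradWeight,
                                               scalar_t* gradInput,
                                               cudaStream_t stream) {
  constexpr int FS = 3;         // assumed specialization
  constexpr int SB = 32;
  constexpr int padding_l = 2;
  // One block per (batch, head, sequence chunk) triple, SB threads each,
  // matching the blockIdx usage inside the backward kernel.
  const int numChunks = (sequenceLength + SB - 1) / SB;
  dim3 blocks(minibatch, numHeads, numChunks);
  dynamicconv_backward_kernel<FS, SB, padding_l, scalar_t>
      <<<blocks, SB, 0, stream>>>(gradOutput, input, weight, minibatch,
                                  sequenceLength, numFeatures,
                                  numFiltersInBlock, numHeads, gradWeight,
                                  gradInput);
}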