danieldk (HF staff) committed
Commit cae2c48 · 0 Parent(s)

Add deformable_detr
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,6 @@
+ ---
+ license: apache-2.0
+ tags:
+ - kernel
+ ---
+
build.toml ADDED
@@ -0,0 +1,19 @@
+ [general]
+ name = "deformable_detr"
+
+ [torch]
+ src = [
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.activation]
+ cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+ "deformable_detr/ms_deform_attn_cuda.cu",
+ "deformable_detr/ms_deform_im2col_cuda.cuh",
+ "deformable_detr/ms_deform_attn_cuda.cuh",
+ "deformable_detr/ms_deform_attn_cuda.h",
+ ]
+ include = ["."]
+ depends = [ "torch" ]
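build.toml points the [torch] extension at torch-ext/torch_binding.cpp and torch_binding.h, which are not part of this commit view. As a hypothetical sketch only (the actual op names, schema, and namespace in those files are assumptions, not shown here), a binding for the two host functions defined in ms_deform_attn_cuda.cu would typically look like this:

// Hypothetical illustration of a torch-ext/torch_binding.cpp-style registration.
// Everything below the declarations is an assumption; only the two function
// signatures are taken from deformable_detr/ms_deform_attn_cuda.cu.
#include <torch/library.h>
#include <ATen/ATen.h>
#include <vector>

// Declarations matching the definitions in ms_deform_attn_cuda.cu
// (normally provided by a header such as torch-ext/torch_binding.h).
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight, const int64_t im2col_step);

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight, const at::Tensor &grad_output,
    const int64_t im2col_step);

// Register a schema and a CUDA implementation under an assumed namespace.
TORCH_LIBRARY(deformable_detr, m) {
  m.def("ms_deform_attn_forward(Tensor value, Tensor spatial_shapes, Tensor level_start_index, "
        "Tensor sampling_loc, Tensor attn_weight, int im2col_step) -> Tensor");
  m.def("ms_deform_attn_backward(Tensor value, Tensor spatial_shapes, Tensor level_start_index, "
        "Tensor sampling_loc, Tensor attn_weight, Tensor grad_output, int im2col_step) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(deformable_detr, CUDA, m) {
  m.impl("ms_deform_attn_forward", &ms_deform_attn_cuda_forward);
  m.impl("ms_deform_attn_backward", &ms_deform_attn_cuda_backward);
}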
deformable_detr/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,158 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #include <vector>
+ #include "deformable_detr/ms_deform_im2col_cuda.cuh"
+
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+ #include <torch/all.h>
+
+
+ at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int64_t im2col_step)
+ {
+ at::DeviceGuard guard(value.device());
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, static_cast<int>(im2col_step));
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data_ptr<int64_t>(),
+ level_start_index.data_ptr<int64_t>(),
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data_ptr<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+ }
+
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int64_t im2col_step)
+ {
+ at::DeviceGuard guard(value.device());
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, static_cast<int>(im2col_step));
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data_ptr<scalar_t>(),
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data_ptr<int64_t>(),
+ level_start_index.data_ptr<int64_t>(),
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+ }
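For orientation, here is a minimal host-side sketch of calling ms_deform_attn_cuda_forward, assuming the file above is compiled and linked in. The tensor layouts follow the shape handling in the function (value: [batch, spatial_size, num_heads, channels]; sampling_loc: [batch, num_query, num_heads, num_levels, num_point, 2]; attn_weight: [batch, num_query, num_heads, num_levels, num_point]); the concrete sizes (two levels of 32×32 and 16×16, 300 queries, 4 points) are purely illustrative.

// Illustrative usage sketch; sizes are assumptions, shapes follow the code above.
#include <torch/all.h>
#include <vector>

at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value, const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight, const int64_t im2col_step);

void example_forward() {
  const int64_t batch = 2, num_heads = 8, channels = 32;
  const int64_t num_query = 300, num_point = 4;

  // Two feature levels (32x32 and 16x16) flattened into one value tensor.
  auto spatial_shapes = torch::tensor({32, 32, 16, 16}, torch::kLong).view({2, 2}).cuda();
  auto level_start_index = torch::tensor({0, 32 * 32}, torch::kLong).cuda();
  const int64_t num_levels = spatial_shapes.size(0);
  const int64_t spatial_size = 32 * 32 + 16 * 16;

  auto opts = torch::device(torch::kCUDA).dtype(torch::kFloat);
  // value: [batch, spatial_size, num_heads, channels]
  auto value = torch::rand({batch, spatial_size, num_heads, channels}, opts);
  // sampling_loc: [batch, num_query, num_heads, num_levels, num_point, 2], normalized to [0, 1]
  auto sampling_loc = torch::rand({batch, num_query, num_heads, num_levels, num_point, 2}, opts);
  // attn_weight: [batch, num_query, num_heads, num_levels, num_point], softmax over levels*points
  auto attn_weight =
      torch::softmax(torch::rand({batch, num_query, num_heads, num_levels * num_point}, opts), -1)
          .view({batch, num_query, num_heads, num_levels, num_point})
          .contiguous();

  // Output shape: [batch, num_query, num_heads * channels]; batch must be divisible by im2col_step.
  auto out = ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
                                         sampling_loc, attn_weight, /*im2col_step=*/2);
}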
deformable_detr/ms_deform_attn_cuda.cuh ADDED
@@ -0,0 +1,1467 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <cuda.h>
14
+ #include <cuda_runtime.h>
15
+
16
+ #include <cstdio>
17
+ #include <algorithm>
18
+ #include <cstring>
19
+
20
+ #include <ATen/ATen.h>
21
+ #include <ATen/cuda/CUDAContext.h>
22
+
23
+ #include <THC/THCAtomics.cuh>
24
+
25
+ #define CUDA_KERNEL_LOOP(i, n) \
26
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
27
+ i < (n); \
28
+ i += blockDim.x * gridDim.x)
29
+
30
+
31
+ at::Tensor ms_deform_attn_cuda_forward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const int im2col_step)
38
+ {
39
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
40
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
41
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
42
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
43
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
44
+
45
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
46
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
47
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
48
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
49
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
50
+
51
+ const int batch = value.size(0);
52
+ const int spatial_size = value.size(1);
53
+ const int num_heads = value.size(2);
54
+ const int channels = value.size(3);
55
+
56
+ const int num_levels = spatial_shapes.size(0);
57
+
58
+ const int num_query = sampling_loc.size(1);
59
+ const int num_point = sampling_loc.size(4);
60
+
61
+ const int im2col_step_ = std::min(batch, im2col_step);
62
+
63
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
64
+
65
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
66
+
67
+ const int batch_n = im2col_step_;
68
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
69
+ auto per_value_size = spatial_size * num_heads * channels;
70
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
71
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
72
+ for (int n = 0; n < batch/im2col_step_; ++n)
73
+ {
74
+ auto columns = output_n.select(0, n);
75
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
76
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
77
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
78
+ spatial_shapes.data_ptr<int64_t>(),
79
+ level_start_index.data_ptr<int64_t>(),
80
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
81
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
82
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
83
+ columns.data_ptr<scalar_t>());
84
+
85
+ }));
86
+ }
87
+
88
+ output = output.view({batch, num_query, num_heads*channels});
89
+
90
+ return output;
91
+ }
92
+
93
+
94
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
95
+ const at::Tensor &value,
96
+ const at::Tensor &spatial_shapes,
97
+ const at::Tensor &level_start_index,
98
+ const at::Tensor &sampling_loc,
99
+ const at::Tensor &attn_weight,
100
+ const at::Tensor &grad_output,
101
+ const int im2col_step)
102
+ {
103
+
104
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
105
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
106
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
107
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
108
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
109
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
110
+
111
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
112
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
113
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
114
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
115
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
116
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
117
+
118
+ const int batch = value.size(0);
119
+ const int spatial_size = value.size(1);
120
+ const int num_heads = value.size(2);
121
+ const int channels = value.size(3);
122
+
123
+ const int num_levels = spatial_shapes.size(0);
124
+
125
+ const int num_query = sampling_loc.size(1);
126
+ const int num_point = sampling_loc.size(4);
127
+
128
+ const int im2col_step_ = std::min(batch, im2col_step);
129
+
130
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
131
+
132
+ auto grad_value = at::zeros_like(value);
133
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
134
+ auto grad_attn_weight = at::zeros_like(attn_weight);
135
+
136
+ const int batch_n = im2col_step_;
137
+ auto per_value_size = spatial_size * num_heads * channels;
138
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
139
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
140
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
141
+
142
+ for (int n = 0; n < batch/im2col_step_; ++n)
143
+ {
144
+ auto grad_output_g = grad_output_n.select(0, n);
145
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
146
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
147
+ grad_output_g.data_ptr<scalar_t>(),
148
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
149
+ spatial_shapes.data_ptr<int64_t>(),
150
+ level_start_index.data_ptr<int64_t>(),
151
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
152
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
153
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
154
+ grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
155
+ grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
156
+ grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
157
+
158
+ }));
159
+ }
160
+
161
+ return {
162
+ grad_value, grad_sampling_loc, grad_attn_weight
163
+ };
164
+ }
165
+
166
+ const int CUDA_NUM_THREADS = 1024;
167
+ inline int GET_BLOCKS(const int N, const int num_threads)
168
+ {
169
+ return (N + num_threads - 1) / num_threads;
170
+ }
171
+
172
+
173
+ template <typename scalar_t>
174
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
175
+ const int &height, const int &width, const int &nheads, const int &channels,
176
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
177
+ {
178
+ const int h_low = floor(h);
179
+ const int w_low = floor(w);
180
+ const int h_high = h_low + 1;
181
+ const int w_high = w_low + 1;
182
+
183
+ const scalar_t lh = h - h_low;
184
+ const scalar_t lw = w - w_low;
185
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
186
+
187
+ const int w_stride = nheads * channels;
188
+ const int h_stride = width * w_stride;
189
+ const int h_low_ptr_offset = h_low * h_stride;
190
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
191
+ const int w_low_ptr_offset = w_low * w_stride;
192
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
193
+ const int base_ptr = m * channels + c;
194
+
195
+ scalar_t v1 = 0;
196
+ if (h_low >= 0 && w_low >= 0)
197
+ {
198
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
199
+ v1 = bottom_data[ptr1];
200
+ }
201
+ scalar_t v2 = 0;
202
+ if (h_low >= 0 && w_high <= width - 1)
203
+ {
204
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
205
+ v2 = bottom_data[ptr2];
206
+ }
207
+ scalar_t v3 = 0;
208
+ if (h_high <= height - 1 && w_low >= 0)
209
+ {
210
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
211
+ v3 = bottom_data[ptr3];
212
+ }
213
+ scalar_t v4 = 0;
214
+ if (h_high <= height - 1 && w_high <= width - 1)
215
+ {
216
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
217
+ v4 = bottom_data[ptr4];
218
+ }
219
+
220
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
221
+
222
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
223
+ return val;
224
+ }
225
+
226
+
227
+ template <typename scalar_t>
228
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
229
+ const int &height, const int &width, const int &nheads, const int &channels,
230
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
231
+ const scalar_t &top_grad,
232
+ const scalar_t &attn_weight,
233
+ scalar_t* &grad_value,
234
+ scalar_t* grad_sampling_loc,
235
+ scalar_t* grad_attn_weight)
236
+ {
237
+ const int h_low = floor(h);
238
+ const int w_low = floor(w);
239
+ const int h_high = h_low + 1;
240
+ const int w_high = w_low + 1;
241
+
242
+ const scalar_t lh = h - h_low;
243
+ const scalar_t lw = w - w_low;
244
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
245
+
246
+ const int w_stride = nheads * channels;
247
+ const int h_stride = width * w_stride;
248
+ const int h_low_ptr_offset = h_low * h_stride;
249
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
250
+ const int w_low_ptr_offset = w_low * w_stride;
251
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
252
+ const int base_ptr = m * channels + c;
253
+
254
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
255
+ const scalar_t top_grad_value = top_grad * attn_weight;
256
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
257
+
258
+ scalar_t v1 = 0;
259
+ if (h_low >= 0 && w_low >= 0)
260
+ {
261
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
262
+ v1 = bottom_data[ptr1];
263
+ grad_h_weight -= hw * v1;
264
+ grad_w_weight -= hh * v1;
265
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
266
+ }
267
+ scalar_t v2 = 0;
268
+ if (h_low >= 0 && w_high <= width - 1)
269
+ {
270
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
271
+ v2 = bottom_data[ptr2];
272
+ grad_h_weight -= lw * v2;
273
+ grad_w_weight += hh * v2;
274
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
275
+ }
276
+ scalar_t v3 = 0;
277
+ if (h_high <= height - 1 && w_low >= 0)
278
+ {
279
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
280
+ v3 = bottom_data[ptr3];
281
+ grad_h_weight += hw * v3;
282
+ grad_w_weight -= lh * v3;
283
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
284
+ }
285
+ scalar_t v4 = 0;
286
+ if (h_high <= height - 1 && w_high <= width - 1)
287
+ {
288
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
289
+ v4 = bottom_data[ptr4];
290
+ grad_h_weight += lw * v4;
291
+ grad_w_weight += lh * v4;
292
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
293
+ }
294
+
295
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
296
+ *grad_attn_weight = top_grad * val;
297
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
298
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
299
+ }
300
+
301
+
302
+ template <typename scalar_t>
303
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
304
+ const int &height, const int &width, const int &nheads, const int &channels,
305
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
306
+ const scalar_t &top_grad,
307
+ const scalar_t &attn_weight,
308
+ scalar_t* &grad_value,
309
+ scalar_t* grad_sampling_loc,
310
+ scalar_t* grad_attn_weight)
311
+ {
312
+ const int h_low = floor(h);
313
+ const int w_low = floor(w);
314
+ const int h_high = h_low + 1;
315
+ const int w_high = w_low + 1;
316
+
317
+ const scalar_t lh = h - h_low;
318
+ const scalar_t lw = w - w_low;
319
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
320
+
321
+ const int w_stride = nheads * channels;
322
+ const int h_stride = width * w_stride;
323
+ const int h_low_ptr_offset = h_low * h_stride;
324
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
325
+ const int w_low_ptr_offset = w_low * w_stride;
326
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
327
+ const int base_ptr = m * channels + c;
328
+
329
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
330
+ const scalar_t top_grad_value = top_grad * attn_weight;
331
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
332
+
333
+ scalar_t v1 = 0;
334
+ if (h_low >= 0 && w_low >= 0)
335
+ {
336
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
337
+ v1 = bottom_data[ptr1];
338
+ grad_h_weight -= hw * v1;
339
+ grad_w_weight -= hh * v1;
340
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
341
+ }
342
+ scalar_t v2 = 0;
343
+ if (h_low >= 0 && w_high <= width - 1)
344
+ {
345
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
346
+ v2 = bottom_data[ptr2];
347
+ grad_h_weight -= lw * v2;
348
+ grad_w_weight += hh * v2;
349
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
350
+ }
351
+ scalar_t v3 = 0;
352
+ if (h_high <= height - 1 && w_low >= 0)
353
+ {
354
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
355
+ v3 = bottom_data[ptr3];
356
+ grad_h_weight += hw * v3;
357
+ grad_w_weight -= lh * v3;
358
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
359
+ }
360
+ scalar_t v4 = 0;
361
+ if (h_high <= height - 1 && w_high <= width - 1)
362
+ {
363
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
364
+ v4 = bottom_data[ptr4];
365
+ grad_h_weight += lw * v4;
366
+ grad_w_weight += lh * v4;
367
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
368
+ }
369
+
370
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
371
+ atomicAdd(grad_attn_weight, top_grad * val);
372
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
373
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
374
+ }
375
+
376
+
377
+ template <typename scalar_t>
378
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
379
+ const scalar_t *data_value,
380
+ const int64_t *data_spatial_shapes,
381
+ const int64_t *data_level_start_index,
382
+ const scalar_t *data_sampling_loc,
383
+ const scalar_t *data_attn_weight,
384
+ const int batch_size,
385
+ const int spatial_size,
386
+ const int num_heads,
387
+ const int channels,
388
+ const int num_levels,
389
+ const int num_query,
390
+ const int num_point,
391
+ scalar_t *data_col)
392
+ {
393
+ CUDA_KERNEL_LOOP(index, n)
394
+ {
395
+ int _temp = index;
396
+ const int c_col = _temp % channels;
397
+ _temp /= channels;
398
+ const int sampling_index = _temp;
399
+ const int m_col = _temp % num_heads;
400
+ _temp /= num_heads;
401
+ [[maybe_unused]] const int q_col = _temp % num_query;
402
+ _temp /= num_query;
403
+ const int b_col = _temp;
404
+
405
+ scalar_t *data_col_ptr = data_col + index;
406
+ int data_weight_ptr = sampling_index * num_levels * num_point;
407
+ int data_loc_w_ptr = data_weight_ptr << 1;
408
+ const int qid_stride = num_heads * channels;
409
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
410
+ scalar_t col = 0;
411
+
412
+ for (int l_col=0; l_col < num_levels; ++l_col)
413
+ {
414
+ const int level_start_id = data_level_start_index[l_col];
415
+ const int spatial_h_ptr = l_col << 1;
416
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
417
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
418
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
419
+ for (int p_col=0; p_col < num_point; ++p_col)
420
+ {
421
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
422
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
423
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
424
+
425
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
426
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
427
+
428
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
429
+ {
430
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
431
+ }
432
+
433
+ data_weight_ptr += 1;
434
+ data_loc_w_ptr += 2;
435
+ }
436
+ }
437
+ *data_col_ptr = col;
438
+ }
439
+ }
440
+
441
+ template <typename scalar_t, unsigned int blockSize>
442
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
443
+ const scalar_t *grad_col,
444
+ const scalar_t *data_value,
445
+ const int64_t *data_spatial_shapes,
446
+ const int64_t *data_level_start_index,
447
+ const scalar_t *data_sampling_loc,
448
+ const scalar_t *data_attn_weight,
449
+ const int batch_size,
450
+ const int spatial_size,
451
+ const int num_heads,
452
+ const int channels,
453
+ const int num_levels,
454
+ const int num_query,
455
+ const int num_point,
456
+ scalar_t *grad_value,
457
+ scalar_t *grad_sampling_loc,
458
+ scalar_t *grad_attn_weight)
459
+ {
460
+ CUDA_KERNEL_LOOP(index, n)
461
+ {
462
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
463
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
464
+ unsigned int tid = threadIdx.x;
465
+ int _temp = index;
466
+ const int c_col = _temp % channels;
467
+ _temp /= channels;
468
+ const int sampling_index = _temp;
469
+ const int m_col = _temp % num_heads;
470
+ _temp /= num_heads;
471
+ [[maybe_unused]] const int q_col = _temp % num_query;
472
+ _temp /= num_query;
473
+ const int b_col = _temp;
474
+
475
+ const scalar_t top_grad = grad_col[index];
476
+
477
+ int data_weight_ptr = sampling_index * num_levels * num_point;
478
+ int data_loc_w_ptr = data_weight_ptr << 1;
479
+ const int grad_sampling_ptr = data_weight_ptr;
480
+ grad_sampling_loc += grad_sampling_ptr << 1;
481
+ grad_attn_weight += grad_sampling_ptr;
482
+ const int grad_weight_stride = 1;
483
+ const int grad_loc_stride = 2;
484
+ const int qid_stride = num_heads * channels;
485
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
486
+
487
+ for (int l_col=0; l_col < num_levels; ++l_col)
488
+ {
489
+ const int level_start_id = data_level_start_index[l_col];
490
+ const int spatial_h_ptr = l_col << 1;
491
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
492
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
493
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
494
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
495
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
496
+
497
+ for (int p_col=0; p_col < num_point; ++p_col)
498
+ {
499
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
500
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
501
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
502
+
503
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
504
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
505
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
506
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
507
+ *(cache_grad_attn_weight+threadIdx.x)=0;
508
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
509
+ {
510
+ ms_deform_attn_col2im_bilinear(
511
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
512
+ top_grad, weight, grad_value_ptr,
513
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
514
+ }
515
+
516
+ __syncthreads();
517
+ if (tid == 0)
518
+ {
519
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
520
+ int sid=2;
521
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
522
+ {
523
+ _grad_w += cache_grad_sampling_loc[sid];
524
+ _grad_h += cache_grad_sampling_loc[sid + 1];
525
+ _grad_a += cache_grad_attn_weight[tid];
526
+ sid += 2;
527
+ }
528
+
529
+
530
+ *grad_sampling_loc = _grad_w;
531
+ *(grad_sampling_loc + 1) = _grad_h;
532
+ *grad_attn_weight = _grad_a;
533
+ }
534
+ __syncthreads();
535
+
536
+ data_weight_ptr += 1;
537
+ data_loc_w_ptr += 2;
538
+ grad_attn_weight += grad_weight_stride;
539
+ grad_sampling_loc += grad_loc_stride;
540
+ }
541
+ }
542
+ }
543
+ }
544
+
545
+
546
+ template <typename scalar_t, unsigned int blockSize>
547
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
548
+ const scalar_t *grad_col,
549
+ const scalar_t *data_value,
550
+ const int64_t *data_spatial_shapes,
551
+ const int64_t *data_level_start_index,
552
+ const scalar_t *data_sampling_loc,
553
+ const scalar_t *data_attn_weight,
554
+ const int batch_size,
555
+ const int spatial_size,
556
+ const int num_heads,
557
+ const int channels,
558
+ const int num_levels,
559
+ const int num_query,
560
+ const int num_point,
561
+ scalar_t *grad_value,
562
+ scalar_t *grad_sampling_loc,
563
+ scalar_t *grad_attn_weight)
564
+ {
565
+ CUDA_KERNEL_LOOP(index, n)
566
+ {
567
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
568
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
569
+ unsigned int tid = threadIdx.x;
570
+ int _temp = index;
571
+ const int c_col = _temp % channels;
572
+ _temp /= channels;
573
+ const int sampling_index = _temp;
574
+ const int m_col = _temp % num_heads;
575
+ _temp /= num_heads;
576
+ [[maybe_unused]] const int q_col = _temp % num_query;
577
+ _temp /= num_query;
578
+ const int b_col = _temp;
579
+
580
+ const scalar_t top_grad = grad_col[index];
581
+
582
+ int data_weight_ptr = sampling_index * num_levels * num_point;
583
+ int data_loc_w_ptr = data_weight_ptr << 1;
584
+ const int grad_sampling_ptr = data_weight_ptr;
585
+ grad_sampling_loc += grad_sampling_ptr << 1;
586
+ grad_attn_weight += grad_sampling_ptr;
587
+ const int grad_weight_stride = 1;
588
+ const int grad_loc_stride = 2;
589
+ const int qid_stride = num_heads * channels;
590
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
591
+
592
+ for (int l_col=0; l_col < num_levels; ++l_col)
593
+ {
594
+ const int level_start_id = data_level_start_index[l_col];
595
+ const int spatial_h_ptr = l_col << 1;
596
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
597
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
598
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
599
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
600
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
601
+
602
+ for (int p_col=0; p_col < num_point; ++p_col)
603
+ {
604
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
605
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
606
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
607
+
608
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
609
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
610
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
611
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
612
+ *(cache_grad_attn_weight+threadIdx.x)=0;
613
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
614
+ {
615
+ ms_deform_attn_col2im_bilinear(
616
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
617
+ top_grad, weight, grad_value_ptr,
618
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
619
+ }
620
+
621
+ __syncthreads();
622
+
623
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
624
+ {
625
+ if (tid < s) {
626
+ const unsigned int xid1 = tid << 1;
627
+ const unsigned int xid2 = (tid + s) << 1;
628
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
629
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
630
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
631
+ }
632
+ __syncthreads();
633
+ }
634
+
635
+ if (tid == 0)
636
+ {
637
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
638
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
639
+ *grad_attn_weight = cache_grad_attn_weight[0];
640
+ }
641
+ __syncthreads();
642
+
643
+ data_weight_ptr += 1;
644
+ data_loc_w_ptr += 2;
645
+ grad_attn_weight += grad_weight_stride;
646
+ grad_sampling_loc += grad_loc_stride;
647
+ }
648
+ }
649
+ }
650
+ }
651
+
652
+
653
+ template <typename scalar_t>
654
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
655
+ const scalar_t *grad_col,
656
+ const scalar_t *data_value,
657
+ const int64_t *data_spatial_shapes,
658
+ const int64_t *data_level_start_index,
659
+ const scalar_t *data_sampling_loc,
660
+ const scalar_t *data_attn_weight,
661
+ const int batch_size,
662
+ const int spatial_size,
663
+ const int num_heads,
664
+ const int channels,
665
+ const int num_levels,
666
+ const int num_query,
667
+ const int num_point,
668
+ scalar_t *grad_value,
669
+ scalar_t *grad_sampling_loc,
670
+ scalar_t *grad_attn_weight)
671
+ {
672
+ CUDA_KERNEL_LOOP(index, n)
673
+ {
674
+ extern __shared__ int _s[];
675
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
676
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
677
+ unsigned int tid = threadIdx.x;
678
+ int _temp = index;
679
+ const int c_col = _temp % channels;
680
+ _temp /= channels;
681
+ const int sampling_index = _temp;
682
+ const int m_col = _temp % num_heads;
683
+ _temp /= num_heads;
684
+ [[maybe_unused]] const int q_col = _temp % num_query;
685
+ _temp /= num_query;
686
+ const int b_col = _temp;
687
+
688
+ const scalar_t top_grad = grad_col[index];
689
+
690
+ int data_weight_ptr = sampling_index * num_levels * num_point;
691
+ int data_loc_w_ptr = data_weight_ptr << 1;
692
+ const int grad_sampling_ptr = data_weight_ptr;
693
+ grad_sampling_loc += grad_sampling_ptr << 1;
694
+ grad_attn_weight += grad_sampling_ptr;
695
+ const int grad_weight_stride = 1;
696
+ const int grad_loc_stride = 2;
697
+ const int qid_stride = num_heads * channels;
698
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
699
+
700
+ for (int l_col=0; l_col < num_levels; ++l_col)
701
+ {
702
+ const int level_start_id = data_level_start_index[l_col];
703
+ const int spatial_h_ptr = l_col << 1;
704
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
705
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
706
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
707
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
708
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
709
+
710
+ for (int p_col=0; p_col < num_point; ++p_col)
711
+ {
712
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
713
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
714
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
715
+
716
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
717
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
718
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
719
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
720
+ *(cache_grad_attn_weight+threadIdx.x)=0;
721
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
722
+ {
723
+ ms_deform_attn_col2im_bilinear(
724
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
725
+ top_grad, weight, grad_value_ptr,
726
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
727
+ }
728
+
729
+ __syncthreads();
730
+ if (tid == 0)
731
+ {
732
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
733
+ int sid=2;
734
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
735
+ {
736
+ _grad_w += cache_grad_sampling_loc[sid];
737
+ _grad_h += cache_grad_sampling_loc[sid + 1];
738
+ _grad_a += cache_grad_attn_weight[tid];
739
+ sid += 2;
740
+ }
741
+
742
+
743
+ *grad_sampling_loc = _grad_w;
744
+ *(grad_sampling_loc + 1) = _grad_h;
745
+ *grad_attn_weight = _grad_a;
746
+ }
747
+ __syncthreads();
748
+
749
+ data_weight_ptr += 1;
750
+ data_loc_w_ptr += 2;
751
+ grad_attn_weight += grad_weight_stride;
752
+ grad_sampling_loc += grad_loc_stride;
753
+ }
754
+ }
755
+ }
756
+ }
757
+
758
+ template <typename scalar_t>
759
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
760
+ const scalar_t *grad_col,
761
+ const scalar_t *data_value,
762
+ const int64_t *data_spatial_shapes,
763
+ const int64_t *data_level_start_index,
764
+ const scalar_t *data_sampling_loc,
765
+ const scalar_t *data_attn_weight,
766
+ const int batch_size,
767
+ const int spatial_size,
768
+ const int num_heads,
769
+ const int channels,
770
+ const int num_levels,
771
+ const int num_query,
772
+ const int num_point,
773
+ scalar_t *grad_value,
774
+ scalar_t *grad_sampling_loc,
775
+ scalar_t *grad_attn_weight)
776
+ {
777
+ CUDA_KERNEL_LOOP(index, n)
778
+ {
779
+ extern __shared__ int _s[];
780
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
781
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
782
+ unsigned int tid = threadIdx.x;
783
+ int _temp = index;
784
+ const int c_col = _temp % channels;
785
+ _temp /= channels;
786
+ const int sampling_index = _temp;
787
+ const int m_col = _temp % num_heads;
788
+ _temp /= num_heads;
789
+ [[maybe_unused]] const int q_col = _temp % num_query;
790
+ _temp /= num_query;
791
+ const int b_col = _temp;
792
+
793
+ const scalar_t top_grad = grad_col[index];
794
+
795
+ int data_weight_ptr = sampling_index * num_levels * num_point;
796
+ int data_loc_w_ptr = data_weight_ptr << 1;
797
+ const int grad_sampling_ptr = data_weight_ptr;
798
+ grad_sampling_loc += grad_sampling_ptr << 1;
799
+ grad_attn_weight += grad_sampling_ptr;
800
+ const int grad_weight_stride = 1;
801
+ const int grad_loc_stride = 2;
802
+ const int qid_stride = num_heads * channels;
803
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
804
+
805
+ for (int l_col=0; l_col < num_levels; ++l_col)
806
+ {
807
+ const int level_start_id = data_level_start_index[l_col];
808
+ const int spatial_h_ptr = l_col << 1;
809
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
810
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
811
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
812
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
813
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
814
+
815
+ for (int p_col=0; p_col < num_point; ++p_col)
816
+ {
817
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
818
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
819
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
820
+
821
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
822
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
823
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
824
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
825
+ *(cache_grad_attn_weight+threadIdx.x)=0;
826
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
827
+ {
828
+ ms_deform_attn_col2im_bilinear(
829
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
830
+ top_grad, weight, grad_value_ptr,
831
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
832
+ }
833
+
834
+ __syncthreads();
835
+
836
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
837
+ {
838
+ if (tid < s) {
839
+ const unsigned int xid1 = tid << 1;
840
+ const unsigned int xid2 = (tid + s) << 1;
841
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
842
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
843
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
844
+ if (tid + (s << 1) < spre)
845
+ {
846
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
847
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
848
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
849
+ }
850
+ }
851
+ __syncthreads();
852
+ }
853
+
854
+ if (tid == 0)
855
+ {
856
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
857
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
858
+ *grad_attn_weight = cache_grad_attn_weight[0];
859
+ }
860
+ __syncthreads();
861
+
862
+ data_weight_ptr += 1;
863
+ data_loc_w_ptr += 2;
864
+ grad_attn_weight += grad_weight_stride;
865
+ grad_sampling_loc += grad_loc_stride;
866
+ }
867
+ }
868
+ }
869
+ }
870
+
871
+ template <typename scalar_t>
872
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
873
+ const scalar_t *grad_col,
874
+ const scalar_t *data_value,
875
+ const int64_t *data_spatial_shapes,
876
+ const int64_t *data_level_start_index,
877
+ const scalar_t *data_sampling_loc,
878
+ const scalar_t *data_attn_weight,
879
+ const int batch_size,
880
+ const int spatial_size,
881
+ const int num_heads,
882
+ const int channels,
883
+ const int num_levels,
884
+ const int num_query,
885
+ const int num_point,
886
+ scalar_t *grad_value,
887
+ scalar_t *grad_sampling_loc,
888
+ scalar_t *grad_attn_weight)
889
+ {
890
+ CUDA_KERNEL_LOOP(index, n)
891
+ {
892
+ extern __shared__ int _s[];
893
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
894
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
895
+ unsigned int tid = threadIdx.x;
896
+ int _temp = index;
897
+ const int c_col = _temp % channels;
898
+ _temp /= channels;
899
+ const int sampling_index = _temp;
900
+ const int m_col = _temp % num_heads;
901
+ _temp /= num_heads;
902
+ [[maybe_unused]] const int q_col = _temp % num_query;
903
+ _temp /= num_query;
904
+ const int b_col = _temp;
905
+
906
+ const scalar_t top_grad = grad_col[index];
907
+
908
+ int data_weight_ptr = sampling_index * num_levels * num_point;
909
+ int data_loc_w_ptr = data_weight_ptr << 1;
910
+ const int grad_sampling_ptr = data_weight_ptr;
911
+ grad_sampling_loc += grad_sampling_ptr << 1;
912
+ grad_attn_weight += grad_sampling_ptr;
913
+ const int grad_weight_stride = 1;
914
+ const int grad_loc_stride = 2;
915
+ const int qid_stride = num_heads * channels;
916
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
917
+
918
+ for (int l_col=0; l_col < num_levels; ++l_col)
919
+ {
920
+ const int level_start_id = data_level_start_index[l_col];
921
+ const int spatial_h_ptr = l_col << 1;
922
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
923
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
924
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
925
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
926
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
927
+
928
+ for (int p_col=0; p_col < num_point; ++p_col)
929
+ {
930
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
931
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
932
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
933
+
934
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
935
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
936
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
937
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
938
+ *(cache_grad_attn_weight+threadIdx.x)=0;
939
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
940
+ {
941
+ ms_deform_attn_col2im_bilinear(
942
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
943
+ top_grad, weight, grad_value_ptr,
944
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
945
+ }
946
+
947
+ __syncthreads();
948
+
949
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
950
+ {
951
+ if (tid < s) {
952
+ const unsigned int xid1 = tid << 1;
953
+ const unsigned int xid2 = (tid + s) << 1;
954
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
955
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
956
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
957
+ if (tid + (s << 1) < spre)
958
+ {
959
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
960
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
961
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
962
+ }
963
+ }
964
+ __syncthreads();
965
+ }
966
+
967
+ if (tid == 0)
968
+ {
969
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
970
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
971
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
972
+ }
973
+ __syncthreads();
974
+
975
+ data_weight_ptr += 1;
976
+ data_loc_w_ptr += 2;
977
+ grad_attn_weight += grad_weight_stride;
978
+ grad_sampling_loc += grad_loc_stride;
979
+ }
980
+ }
981
+ }
982
+ }
983
+
984
+
985
+ template <typename scalar_t>
986
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
987
+ const scalar_t *grad_col,
988
+ const scalar_t *data_value,
989
+ const int64_t *data_spatial_shapes,
990
+ const int64_t *data_level_start_index,
991
+ const scalar_t *data_sampling_loc,
992
+ const scalar_t *data_attn_weight,
993
+ const int batch_size,
994
+ const int spatial_size,
995
+ const int num_heads,
996
+ const int channels,
997
+ const int num_levels,
998
+ const int num_query,
999
+ const int num_point,
1000
+ scalar_t *grad_value,
1001
+ scalar_t *grad_sampling_loc,
1002
+ scalar_t *grad_attn_weight)
1003
+ {
1004
+ CUDA_KERNEL_LOOP(index, n)
1005
+ {
1006
+ int _temp = index;
1007
+ const int c_col = _temp % channels;
1008
+ _temp /= channels;
1009
+ const int sampling_index = _temp;
1010
+ const int m_col = _temp % num_heads;
1011
+ _temp /= num_heads;
1012
+ [[maybe_unused]] const int q_col = _temp % num_query;
1013
+ _temp /= num_query;
1014
+ const int b_col = _temp;
1015
+
1016
+ const scalar_t top_grad = grad_col[index];
1017
+
1018
+ int data_weight_ptr = sampling_index * num_levels * num_point;
1019
+ int data_loc_w_ptr = data_weight_ptr << 1;
1020
+ const int grad_sampling_ptr = data_weight_ptr;
1021
+ grad_sampling_loc += grad_sampling_ptr << 1;
1022
+ grad_attn_weight += grad_sampling_ptr;
1023
+ const int grad_weight_stride = 1;
1024
+ const int grad_loc_stride = 2;
1025
+ const int qid_stride = num_heads * channels;
1026
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
1027
+
1028
+ for (int l_col=0; l_col < num_levels; ++l_col)
1029
+ {
1030
+ const int level_start_id = data_level_start_index[l_col];
1031
+ const int spatial_h_ptr = l_col << 1;
1032
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
1033
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
1034
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
1035
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
1036
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
1037
+
1038
+ for (int p_col=0; p_col < num_point; ++p_col)
1039
+ {
1040
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
1041
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
1042
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
1043
+
1044
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
1045
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
1046
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
1047
+ {
1048
+ ms_deform_attn_col2im_bilinear_gm(
1049
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
1050
+ top_grad, weight, grad_value_ptr,
1051
+ grad_sampling_loc, grad_attn_weight);
1052
+ }
1053
+ data_weight_ptr += 1;
1054
+ data_loc_w_ptr += 2;
1055
+ grad_attn_weight += grad_weight_stride;
1056
+ grad_sampling_loc += grad_loc_stride;
1057
+ }
1058
+ }
1059
+ }
1060
+ }
1061
+
1062
+
1063
+ template <typename scalar_t>
1064
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
1065
+ const scalar_t* data_value,
1066
+ const int64_t* data_spatial_shapes,
1067
+ const int64_t* data_level_start_index,
1068
+ const scalar_t* data_sampling_loc,
1069
+ const scalar_t* data_attn_weight,
1070
+ const int batch_size,
1071
+ const int spatial_size,
1072
+ const int num_heads,
1073
+ const int channels,
1074
+ const int num_levels,
1075
+ const int num_query,
1076
+ const int num_point,
1077
+ scalar_t* data_col)
1078
+ {
1079
+ const int num_kernels = batch_size * num_query * num_heads * channels;
1080
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
1081
+ const int num_threads = CUDA_NUM_THREADS;
1082
+ ms_deformable_im2col_gpu_kernel<scalar_t>
1083
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1084
+ 0, stream>>>(
1085
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
1086
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
1087
+
1088
+ cudaError_t err = cudaGetLastError();
1089
+ if (err != cudaSuccess)
1090
+ {
1091
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
1092
+ }
1093
+
1094
+ }
1095
+
1096
+ template <typename scalar_t>
1097
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
1098
+ const scalar_t* grad_col,
1099
+ const scalar_t* data_value,
1100
+ const int64_t * data_spatial_shapes,
1101
+ const int64_t * data_level_start_index,
1102
+ const scalar_t * data_sampling_loc,
1103
+ const scalar_t * data_attn_weight,
1104
+ const int batch_size,
1105
+ const int spatial_size,
1106
+ const int num_heads,
1107
+ const int channels,
1108
+ const int num_levels,
1109
+ const int num_query,
1110
+ const int num_point,
1111
+ scalar_t* grad_value,
1112
+ scalar_t* grad_sampling_loc,
1113
+ scalar_t* grad_attn_weight)
1114
+ {
1115
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
1116
+ const int num_kernels = batch_size * num_query * num_heads * channels;
1117
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
1118
+ if (channels > 1024)
1119
+ {
1120
+ if ((channels & 1023) == 0)
1121
+ {
1122
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
1123
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1124
+ num_threads*3*sizeof(scalar_t), stream>>>(
1125
+ num_kernels,
1126
+ grad_col,
1127
+ data_value,
1128
+ data_spatial_shapes,
1129
+ data_level_start_index,
1130
+ data_sampling_loc,
1131
+ data_attn_weight,
1132
+ batch_size,
1133
+ spatial_size,
1134
+ num_heads,
1135
+ channels,
1136
+ num_levels,
1137
+ num_query,
1138
+ num_point,
1139
+ grad_value,
1140
+ grad_sampling_loc,
1141
+ grad_attn_weight);
1142
+ }
1143
+ else
1144
+ {
1145
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1146
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1147
+ 0, stream>>>(
1148
+ num_kernels,
1149
+ grad_col,
1150
+ data_value,
1151
+ data_spatial_shapes,
1152
+ data_level_start_index,
1153
+ data_sampling_loc,
1154
+ data_attn_weight,
1155
+ batch_size,
1156
+ spatial_size,
1157
+ num_heads,
1158
+ channels,
1159
+ num_levels,
1160
+ num_query,
1161
+ num_point,
1162
+ grad_value,
1163
+ grad_sampling_loc,
1164
+ grad_attn_weight);
1165
+ }
1166
+ }
1167
+ else{
1168
+ switch(channels)
1169
+ {
1170
+ case 1:
1171
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1172
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1173
+ 0, stream>>>(
1174
+ num_kernels,
1175
+ grad_col,
1176
+ data_value,
1177
+ data_spatial_shapes,
1178
+ data_level_start_index,
1179
+ data_sampling_loc,
1180
+ data_attn_weight,
1181
+ batch_size,
1182
+ spatial_size,
1183
+ num_heads,
1184
+ channels,
1185
+ num_levels,
1186
+ num_query,
1187
+ num_point,
1188
+ grad_value,
1189
+ grad_sampling_loc,
1190
+ grad_attn_weight);
1191
+ break;
1192
+ case 2:
1193
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1194
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1195
+ 0, stream>>>(
1196
+ num_kernels,
1197
+ grad_col,
1198
+ data_value,
1199
+ data_spatial_shapes,
1200
+ data_level_start_index,
1201
+ data_sampling_loc,
1202
+ data_attn_weight,
1203
+ batch_size,
1204
+ spatial_size,
1205
+ num_heads,
1206
+ channels,
1207
+ num_levels,
1208
+ num_query,
1209
+ num_point,
1210
+ grad_value,
1211
+ grad_sampling_loc,
1212
+ grad_attn_weight);
1213
+ break;
1214
+ case 4:
1215
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1216
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1217
+ 0, stream>>>(
1218
+ num_kernels,
1219
+ grad_col,
1220
+ data_value,
1221
+ data_spatial_shapes,
1222
+ data_level_start_index,
1223
+ data_sampling_loc,
1224
+ data_attn_weight,
1225
+ batch_size,
1226
+ spatial_size,
1227
+ num_heads,
1228
+ channels,
1229
+ num_levels,
1230
+ num_query,
1231
+ num_point,
1232
+ grad_value,
1233
+ grad_sampling_loc,
1234
+ grad_attn_weight);
1235
+ break;
1236
+ case 8:
1237
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1238
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1239
+ 0, stream>>>(
1240
+ num_kernels,
1241
+ grad_col,
1242
+ data_value,
1243
+ data_spatial_shapes,
1244
+ data_level_start_index,
1245
+ data_sampling_loc,
1246
+ data_attn_weight,
1247
+ batch_size,
1248
+ spatial_size,
1249
+ num_heads,
1250
+ channels,
1251
+ num_levels,
1252
+ num_query,
1253
+ num_point,
1254
+ grad_value,
1255
+ grad_sampling_loc,
1256
+ grad_attn_weight);
1257
+ break;
1258
+ case 16:
1259
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1260
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1261
+ 0, stream>>>(
1262
+ num_kernels,
1263
+ grad_col,
1264
+ data_value,
1265
+ data_spatial_shapes,
1266
+ data_level_start_index,
1267
+ data_sampling_loc,
1268
+ data_attn_weight,
1269
+ batch_size,
1270
+ spatial_size,
1271
+ num_heads,
1272
+ channels,
1273
+ num_levels,
1274
+ num_query,
1275
+ num_point,
1276
+ grad_value,
1277
+ grad_sampling_loc,
1278
+ grad_attn_weight);
1279
+ break;
1280
+ case 32:
1281
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1282
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1283
+ 0, stream>>>(
1284
+ num_kernels,
1285
+ grad_col,
1286
+ data_value,
1287
+ data_spatial_shapes,
1288
+ data_level_start_index,
1289
+ data_sampling_loc,
1290
+ data_attn_weight,
1291
+ batch_size,
1292
+ spatial_size,
1293
+ num_heads,
1294
+ channels,
1295
+ num_levels,
1296
+ num_query,
1297
+ num_point,
1298
+ grad_value,
1299
+ grad_sampling_loc,
1300
+ grad_attn_weight);
1301
+ break;
1302
+ case 64:
1303
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1304
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1305
+ 0, stream>>>(
1306
+ num_kernels,
1307
+ grad_col,
1308
+ data_value,
1309
+ data_spatial_shapes,
1310
+ data_level_start_index,
1311
+ data_sampling_loc,
1312
+ data_attn_weight,
1313
+ batch_size,
1314
+ spatial_size,
1315
+ num_heads,
1316
+ channels,
1317
+ num_levels,
1318
+ num_query,
1319
+ num_point,
1320
+ grad_value,
1321
+ grad_sampling_loc,
1322
+ grad_attn_weight);
1323
+ break;
1324
+ case 128:
1325
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1326
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1327
+ 0, stream>>>(
1328
+ num_kernels,
1329
+ grad_col,
1330
+ data_value,
1331
+ data_spatial_shapes,
1332
+ data_level_start_index,
1333
+ data_sampling_loc,
1334
+ data_attn_weight,
1335
+ batch_size,
1336
+ spatial_size,
1337
+ num_heads,
1338
+ channels,
1339
+ num_levels,
1340
+ num_query,
1341
+ num_point,
1342
+ grad_value,
1343
+ grad_sampling_loc,
1344
+ grad_attn_weight);
1345
+ break;
1346
+ case 256:
1347
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1348
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1349
+ 0, stream>>>(
1350
+ num_kernels,
1351
+ grad_col,
1352
+ data_value,
1353
+ data_spatial_shapes,
1354
+ data_level_start_index,
1355
+ data_sampling_loc,
1356
+ data_attn_weight,
1357
+ batch_size,
1358
+ spatial_size,
1359
+ num_heads,
1360
+ channels,
1361
+ num_levels,
1362
+ num_query,
1363
+ num_point,
1364
+ grad_value,
1365
+ grad_sampling_loc,
1366
+ grad_attn_weight);
1367
+ break;
1368
+ case 512:
1369
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1370
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1371
+ 0, stream>>>(
1372
+ num_kernels,
1373
+ grad_col,
1374
+ data_value,
1375
+ data_spatial_shapes,
1376
+ data_level_start_index,
1377
+ data_sampling_loc,
1378
+ data_attn_weight,
1379
+ batch_size,
1380
+ spatial_size,
1381
+ num_heads,
1382
+ channels,
1383
+ num_levels,
1384
+ num_query,
1385
+ num_point,
1386
+ grad_value,
1387
+ grad_sampling_loc,
1388
+ grad_attn_weight);
1389
+ break;
1390
+ case 1024:
1391
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1392
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1393
+ 0, stream>>>(
1394
+ num_kernels,
1395
+ grad_col,
1396
+ data_value,
1397
+ data_spatial_shapes,
1398
+ data_level_start_index,
1399
+ data_sampling_loc,
1400
+ data_attn_weight,
1401
+ batch_size,
1402
+ spatial_size,
1403
+ num_heads,
1404
+ channels,
1405
+ num_levels,
1406
+ num_query,
1407
+ num_point,
1408
+ grad_value,
1409
+ grad_sampling_loc,
1410
+ grad_attn_weight);
1411
+ break;
1412
+ default:
1413
+ if (channels < 64)
1414
+ {
1415
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1416
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1417
+ num_threads*3*sizeof(scalar_t), stream>>>(
1418
+ num_kernels,
1419
+ grad_col,
1420
+ data_value,
1421
+ data_spatial_shapes,
1422
+ data_level_start_index,
1423
+ data_sampling_loc,
1424
+ data_attn_weight,
1425
+ batch_size,
1426
+ spatial_size,
1427
+ num_heads,
1428
+ channels,
1429
+ num_levels,
1430
+ num_query,
1431
+ num_point,
1432
+ grad_value,
1433
+ grad_sampling_loc,
1434
+ grad_attn_weight);
1435
+ }
1436
+ else
1437
+ {
1438
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1439
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1440
+ num_threads*3*sizeof(scalar_t), stream>>>(
1441
+ num_kernels,
1442
+ grad_col,
1443
+ data_value,
1444
+ data_spatial_shapes,
1445
+ data_level_start_index,
1446
+ data_sampling_loc,
1447
+ data_attn_weight,
1448
+ batch_size,
1449
+ spatial_size,
1450
+ num_heads,
1451
+ channels,
1452
+ num_levels,
1453
+ num_query,
1454
+ num_point,
1455
+ grad_value,
1456
+ grad_sampling_loc,
1457
+ grad_attn_weight);
1458
+ }
1459
+ }
1460
+ }
1461
+ cudaError_t err = cudaGetLastError();
1462
+ if (err != cudaSuccess)
1463
+ {
1464
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1465
+ }
1466
+
1467
+ }
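Note: the launcher above picks a backward kernel variant purely from the head dimension: power-of-two channel counts up to 32 use the blocksize-aware v1 reduction (thread 0 sums the shared-memory cache linearly), 64 through 1024 use the v2 tree reduction, and channel counts above 1024 use the multi-block shared-memory kernel when divisible by 1024, otherwise the global-memory atomic fallback. A minimal Python sketch of that selection, for illustration only (names mirror the kernels above):

CUDA_NUM_THREADS = 1024

def pick_col2im_kernel(channels: int) -> str:
    # Mirrors the channel-based dispatch of ms_deformable_col2im_cuda above.
    if channels > CUDA_NUM_THREADS:
        # very wide heads: multi-block shared-memory reduction when divisible,
        # otherwise the global-memory atomic fallback
        return "shm_reduce_v2_multi_blocks" if channels % 1024 == 0 else "gm"
    if channels in (1, 2, 4, 8, 16, 32):
        return f"shm_blocksize_aware_reduce_v1<{channels}>"  # thread 0 sums the cache linearly
    if channels in (64, 128, 256, 512, 1024):
        return f"shm_blocksize_aware_reduce_v2<{channels}>"  # shared-memory tree reduction
    # non-power-of-two channel counts use the dynamically sized shared-memory kernels
    return "shm_reduce_v1" if channels < 64 else "shm_reduce_v2"

print(pick_col2im_kernel(64))  # shm_blocksize_aware_reduce_v2<64>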
deformable_detr/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,46 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/torch.h>
13
+
14
+ at::Tensor ms_deform_attn_cuda_forward(
15
+ const at::Tensor &value,
16
+ const at::Tensor &spatial_shapes,
17
+ const at::Tensor &level_start_index,
18
+ const at::Tensor &sampling_loc,
19
+ const at::Tensor &attn_weight,
20
+ const int im2col_step);
21
+
22
+ at::Tensor ms_deform_attn_cuda_forward_bf16(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const int im2col_step);
29
+
30
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
31
+ const at::Tensor &value,
32
+ const at::Tensor &spatial_shapes,
33
+ const at::Tensor &level_start_index,
34
+ const at::Tensor &sampling_loc,
35
+ const at::Tensor &attn_weight,
36
+ const at::Tensor &grad_output,
37
+ const int im2col_step);
38
+
39
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
40
+ const at::Tensor &value,
41
+ const at::Tensor &spatial_shapes,
42
+ const at::Tensor &level_start_index,
43
+ const at::Tensor &sampling_loc,
44
+ const at::Tensor &attn_weight,
45
+ const at::Tensor &grad_output,
46
+ const int im2col_step);
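Note: the header declares paired float32 and bfloat16 entry points for the forward and backward passes. The tensor layout they expect follows the upstream Deformable DETR convention; the helper below is a hypothetical sketch (not part of this commit) that spells out those shapes and the expected output shape:

from typing import Tuple
import torch

def check_ms_deform_attn_inputs(
    value: torch.Tensor,              # (N, S, M, D) values flattened over all levels
    spatial_shapes: torch.Tensor,     # (L, 2) int64, (H_l, W_l) per level
    level_start_index: torch.Tensor,  # (L,) int64, offset of each level inside S
    sampling_loc: torch.Tensor,       # (N, Q, M, L, P, 2) normalized (x, y) in [0, 1]
    attn_weight: torch.Tensor,        # (N, Q, M, L, P) softmax-normalized per query
) -> Tuple[int, int, int]:
    N, S, M, D = value.shape
    L = spatial_shapes.shape[0]
    n, Q, m, l, P, two = sampling_loc.shape
    assert (n, m, l, two) == (N, M, L, 2)
    assert attn_weight.shape == (N, Q, M, L, P)
    assert level_start_index.shape == (L,)
    assert int(spatial_shapes.prod(dim=1).sum()) == S  # levels tile the flattened value
    return (N, Q, M * D)  # shape of the tensor returned by the forward entry points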
deformable_detr/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
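Note: the col2im helper below reuses the same four corner weights as the forward sampler above. For reference, a minimal Python sketch of that bilinear sampling (corners outside the feature map contribute zero, matching the bounds checks above):

import math

def bilinear_sample(value, h, w):
    # value: 2-D array indexed as value[y][x]; (h, w) is a fractional sampling location
    H, W = len(value), len(value[0])
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw

    def at(y, x):
        # out-of-range corners read as zero, like the h_low/w_low/h_high/w_high guards above
        return value[y][x] if 0 <= y < H and 0 <= x < W else 0.0

    # w1..w4 mirror the weights computed in ms_deform_attn_im2col_bilinear
    return (hh * hw * at(h_low, w_low) + hh * lw * at(h_low, w_high)
            + lh * hw * at(h_high, w_low) + lh * lw * at(h_high, w_high))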
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ [[maybe_unused]] const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
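Note: the im2col kernel above computes the same quantity as the well-known pure-PyTorch reference for multi-scale deformable attention. A sketch of that reference (grid_sample based, shapes following the Deformable DETR convention) can be useful for checking the CUDA path against eager PyTorch:

import torch
import torch.nn.functional as F

def ms_deform_attn_forward_reference(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (N, S, M, D); spatial_shapes: (L, 2)
    # sampling_locations: (N, Q, M, L, P, 2) in [0, 1]; attention_weights: (N, Q, M, L, P)
    N, _, M, D = value.shape
    _, Q, _, L, P, _ = sampling_locations.shape
    value_list = value.split([int(h) * int(w) for h, w in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
    sampled_per_level = []
    for level, (h, w) in enumerate(spatial_shapes):
        # (N, h*w, M, D) -> (N*M, D, h, w)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(N * M, D, int(h), int(w))
        # (N, Q, M, P, 2) -> (N*M, Q, P, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # align_corners=False reproduces the loc * size - 0.5 mapping used in the kernel above
        sampled_per_level.append(
            F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False)
        )  # each entry: (N*M, D, Q, P)
    attn = attention_weights.transpose(1, 2).reshape(N * M, 1, Q, L * P)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attn).sum(-1)  # (N*M, D, Q)
    return out.view(N, M * D, Q).transpose(1, 2).contiguous()  # (N, Q, M*D)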
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ [[maybe_unused]] const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ [[maybe_unused]] const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ [[maybe_unused]] const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ [[maybe_unused]] const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ [[maybe_unused]] const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ [[maybe_unused]] const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+         break;
+       case 64:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+           num_kernels,
+           grad_col,
+           data_value,
+           data_spatial_shapes,
+           data_level_start_index,
+           data_sampling_loc,
+           data_attn_weight,
+           batch_size,
+           spatial_size,
+           num_heads,
+           channels,
+           num_levels,
+           num_query,
+           num_point,
+           grad_value,
+           grad_sampling_loc,
+           grad_attn_weight);
+         break;
+       case 128:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+           num_kernels,
+           grad_col,
+           data_value,
+           data_spatial_shapes,
+           data_level_start_index,
+           data_sampling_loc,
+           data_attn_weight,
+           batch_size,
+           spatial_size,
+           num_heads,
+           channels,
+           num_levels,
+           num_query,
+           num_point,
+           grad_value,
+           grad_sampling_loc,
+           grad_attn_weight);
+         break;
+       case 256:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+           num_kernels,
+           grad_col,
+           data_value,
+           data_spatial_shapes,
+           data_level_start_index,
+           data_sampling_loc,
+           data_attn_weight,
+           batch_size,
+           spatial_size,
+           num_heads,
+           channels,
+           num_levels,
+           num_query,
+           num_point,
+           grad_value,
+           grad_sampling_loc,
+           grad_attn_weight);
+         break;
+       case 512:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+           num_kernels,
+           grad_col,
+           data_value,
+           data_spatial_shapes,
+           data_level_start_index,
+           data_sampling_loc,
+           data_attn_weight,
+           batch_size,
+           spatial_size,
+           num_heads,
+           channels,
+           num_levels,
+           num_query,
+           num_point,
+           grad_value,
+           grad_sampling_loc,
+           grad_attn_weight);
+         break;
+       case 1024:
+         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             0, stream>>>(
+           num_kernels,
+           grad_col,
+           data_value,
+           data_spatial_shapes,
+           data_level_start_index,
+           data_sampling_loc,
+           data_attn_weight,
+           batch_size,
+           spatial_size,
+           num_heads,
+           channels,
+           num_levels,
+           num_query,
+           num_point,
+           grad_value,
+           grad_sampling_loc,
+           grad_attn_weight);
+         break;
+       default:
+         if (channels < 64)
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads*3*sizeof(scalar_t), stream>>>(
+             num_kernels,
+             grad_col,
+             data_value,
+             data_spatial_shapes,
+             data_level_start_index,
+             data_sampling_loc,
+             data_attn_weight,
+             batch_size,
+             spatial_size,
+             num_heads,
+             channels,
+             num_levels,
+             num_query,
+             num_point,
+             grad_value,
+             grad_sampling_loc,
+             grad_attn_weight);
+         }
+         else
+         {
+           ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads*3*sizeof(scalar_t), stream>>>(
+             num_kernels,
+             grad_col,
+             data_value,
+             data_spatial_shapes,
+             data_level_start_index,
+             data_sampling_loc,
+             data_attn_weight,
+             batch_size,
+             spatial_size,
+             num_heads,
+             channels,
+             num_levels,
+             num_query,
+             num_point,
+             grad_value,
+             grad_sampling_loc,
+             grad_attn_weight);
+         }
+     }
+   }
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+   }
+
+ }
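Note (not part of the commit): the `switch (channels)` above launches a template specialization of the shared-memory, block-size-aware reduction kernel for the channel counts it can unroll, while the `default:` label falls back to the generic `shm_reduce_v1`/`shm_reduce_v2` kernels depending on whether `channels < 64`. Below is a minimal Python sketch of that dispatch decision; it only models the cases visible in this hunk (64 through 1024), so the handling it gives to smaller channel counts (covered by earlier `case` labels not shown here) is an assumption.

# Hypothetical illustration of the col2im kernel dispatch; not part of the commit.
# Only the `case 64:` ... `case 1024:` labels and the `default:` branch are visible
# in this hunk, so everything below 64 is modelled as the generic fallback.

SPECIALIZED_CHANNELS = (64, 128, 256, 512, 1024)  # case labels visible above


def pick_col2im_kernel(channels: int) -> str:
    """Return (as a string) the kernel the switch statement would launch."""
    if channels in SPECIALIZED_CHANNELS:
        # Template-specialized on the channel count; launched with no dynamic
        # shared memory (the `0, stream` launch argument above).
        return (
            "ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2"
            f"<scalar_t, {channels}>"
        )
    if channels < 64:
        # Generic reduction using num_threads * 3 * sizeof(scalar_t) bytes of
        # dynamic shared memory (see the `default:` branch above).
        return "ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>"
    return "ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>"


if __name__ == "__main__":
    for c in (48, 64, 100, 512, 768):
        print(c, "->", pick_col2im_kernel(c))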
flake.nix ADDED
@@ -0,0 +1,14 @@
+ {
+   description = "Flake for deformable_detr kernels";
+
+   inputs = {
+     kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs ./.;
+ }
torch-ext/deformable_detr/__init__.py ADDED
@@ -0,0 +1,45 @@
+ from typing import List
+ import torch
+
+ from ._ops import ops
+
+
+ def ms_deform_attn_backward(
+     value: torch.Tensor,
+     spatial_shapes: torch.Tensor,
+     level_start_index: torch.Tensor,
+     sampling_loc: torch.Tensor,
+     attn_weight: torch.Tensor,
+     grad_output: torch.Tensor,
+     im2col_step: int,
+ ) -> List[torch.Tensor]:
+     return ops.ms_deform_attn_backward(
+         value,
+         spatial_shapes,
+         level_start_index,
+         sampling_loc,
+         attn_weight,
+         grad_output,
+         im2col_step,
+     )
+
+
+ def ms_deform_attn_forward(
+     value: torch.Tensor,
+     spatial_shapes: torch.Tensor,
+     level_start_index: torch.Tensor,
+     sampling_loc: torch.Tensor,
+     attn_weight: torch.Tensor,
+     im2col_step: int,
+ ) -> torch.Tensor:
+     return ops.ms_deform_attn_forward(
+         value,
+         spatial_shapes,
+         level_start_index,
+         sampling_loc,
+         attn_weight,
+         im2col_step,
+     )
+
+
+ __all__ = ["ms_deform_attn_forward", "ms_deform_attn_backward"]
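Note (not part of the commit): a hedged usage sketch of the `ms_deform_attn_forward` wrapper above. The shapes follow the usual multi-scale deformable attention layout (`value` as `(batch, len, heads, head_dim)`, sampling locations as `(batch, queries, heads, levels, points, 2)`); both the shapes and the assumption that the built extension is importable as `deformable_detr` on a CUDA machine are illustrative, not taken from this diff.

# Hypothetical usage sketch; shapes and the import path are assumptions.
import torch

from deformable_detr import ms_deform_attn_forward

device = "cuda"
batch, heads, head_dim, queries, points = 2, 8, 32, 100, 4

# Two feature levels, e.g. 32x32 and 16x16 feature maps flattened into one sequence.
spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long, device=device)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
)
value_len = int(spatial_shapes.prod(1).sum())
levels = spatial_shapes.size(0)

value = torch.rand(batch, value_len, heads, head_dim, device=device)
sampling_loc = torch.rand(batch, queries, heads, levels, points, 2, device=device)
attn_weight = torch.rand(batch, queries, heads, levels, points, device=device)
attn_weight = attn_weight / attn_weight.sum(dim=(-2, -1), keepdim=True)

out = ms_deform_attn_forward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 2
)
print(out.shape)  # expected: (batch, queries, heads * head_dim)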
torch-ext/deformable_detr/layers.py ADDED
@@ -0,0 +1,81 @@
+ from typing import List, Union, Tuple
+
+ from torch import Tensor
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+ import torch.nn as nn
+
+ from ._ops import ops
+
+
+ class MultiScaleDeformableAttentionFunction(Function):
+     @staticmethod
+     def forward(
+         context,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         context.im2col_step = im2col_step
+         output = ops.ms_deform_attn_forward(
+             value,
+             value_spatial_shapes,
+             value_level_start_index,
+             sampling_locations,
+             attention_weights,
+             context.im2col_step,
+         )
+         context.save_for_backward(
+             value,
+             value_spatial_shapes,
+             value_level_start_index,
+             sampling_locations,
+             attention_weights,
+         )
+         return output
+
+     @staticmethod
+     @once_differentiable
+     def backward(context, grad_output):
+         (
+             value,
+             value_spatial_shapes,
+             value_level_start_index,
+             sampling_locations,
+             attention_weights,
+         ) = context.saved_tensors
+         grad_value, grad_sampling_loc, grad_attn_weight = ops.ms_deform_attn_backward(
+             value,
+             value_spatial_shapes,
+             value_level_start_index,
+             sampling_locations,
+             attention_weights,
+             grad_output,
+             context.im2col_step,
+         )
+
+         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+ class MultiScaleDeformableAttention(nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: List[Tuple],
+         level_start_index: Tensor,
+         sampling_locations: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         return MultiScaleDeformableAttentionFunction.apply(
+             value,
+             value_spatial_shapes,
+             level_start_index,
+             sampling_locations,
+             attention_weights,
+             im2col_step,
+         )
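Note (not part of the commit): a minimal sketch of driving `MultiScaleDeformableAttentionFunction` through autograd, under the same assumed shapes as the previous sketch. Gradients for `value`, the sampling locations, and the attention weights come from the custom backward; the index tensors and `im2col_step` receive `None`. The explicit `torch.ones_like` gradient keeps `grad_output` contiguous, which the CUDA backward is assumed to require.

# Hypothetical autograd sketch; shapes are assumptions, a CUDA build is assumed.
import torch

from deformable_detr.layers import MultiScaleDeformableAttentionFunction

device = "cuda"
spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long, device=device)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
)
value_len = int(spatial_shapes.prod(1).sum())

value = torch.rand(2, value_len, 8, 32, device=device, requires_grad=True)
sampling_locations = torch.rand(2, 100, 8, 2, 4, 2, device=device, requires_grad=True)
w = torch.rand(2, 100, 8, 2, 4, device=device)
attention_weights = (w / w.sum(dim=(-2, -1), keepdim=True)).requires_grad_(True)

out = MultiScaleDeformableAttentionFunction.apply(
    value, spatial_shapes, level_start_index,
    sampling_locations, attention_weights, 2,
)
# Pass an explicit, contiguous grad_output; backward runs ops.ms_deform_attn_backward once.
out.backward(torch.ones_like(out))
print(value.grad.shape, sampling_locations.grad.shape, attention_weights.grad.shape)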
torch-ext/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,40 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #include <vector>
+
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+
+ at::Tensor
+ ms_deform_attn_cpu_forward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const int im2col_step)
+ {
+     AT_ERROR("Not implemented on CPU");
+ }
+
+ std::vector<at::Tensor>
+ ms_deform_attn_cpu_backward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const at::Tensor &grad_output,
+     const int im2col_step)
+ {
+     AT_ERROR("Not implemented on CPU");
+ }
torch-ext/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,32 @@
+ /*!
+ **************************************************************************************************
+ * Deformable DETR
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ **************************************************************************************************
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ **************************************************************************************************
+ */
+
+ #pragma once
+ #include <torch/extension.h>
+
+ at::Tensor
+ ms_deform_attn_cpu_forward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const int im2col_step);
+
+ std::vector<at::Tensor>
+ ms_deform_attn_cpu_backward(
+     const at::Tensor &value,
+     const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index,
+     const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight,
+     const at::Tensor &grad_output,
+     const int im2col_step);
+
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,19 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+   ops.def("ms_deform_attn_forward(Tensor value, Tensor spatial_shapes,"
+           " Tensor level_start_index, Tensor sampling_loc,"
+           " Tensor attn_weight, int im2col_step) -> Tensor");
+   ops.impl("ms_deform_attn_forward", torch::kCUDA, &ms_deform_attn_cuda_forward);
+
+   ops.def("ms_deform_attn_backward(Tensor value, Tensor spatial_shapes,"
+           " Tensor level_start_index, Tensor sampling_loc,"
+           " Tensor attn_weight, Tensor grad_output,"
+           " int im2col_step) -> Tensor[]");
+   ops.impl("ms_deform_attn_backward", torch::kCUDA, &ms_deform_attn_cuda_backward);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
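Note (not part of the commit): the two schemas registered above are what the Python wrappers in `torch-ext/deformable_detr/__init__.py` resolve to. A hedged way to sanity-check that the backward op matches the forward numerically is `torch.autograd.gradcheck` on the autograd `Function`, sketched below. It assumes the kernels are also compiled for double precision and uses deliberately tiny, hypothetical shapes to keep the check tractable.

# Hypothetical gradcheck sketch; assumes float64 kernel support, tiny assumed shapes.
import torch
from torch.autograd import gradcheck

from deformable_detr.layers import MultiScaleDeformableAttentionFunction

device = "cuda"
spatial_shapes = torch.as_tensor([[4, 4], [2, 2]], dtype=torch.long, device=device)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
)
value_len = int(spatial_shapes.prod(1).sum())

value = torch.rand(1, value_len, 2, 4, dtype=torch.float64, device=device, requires_grad=True)
sampling_loc = torch.rand(1, 3, 2, 2, 2, 2, dtype=torch.float64, device=device, requires_grad=True)
w = torch.rand(1, 3, 2, 2, 2, dtype=torch.float64, device=device)
attn_weight = (w / w.sum(dim=(-2, -1), keepdim=True)).requires_grad_(True)

# Compares the custom CUDA backward against numerical gradients of the forward.
ok = gradcheck(
    MultiScaleDeformableAttentionFunction.apply,
    (value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 1),
    eps=1e-6,
    atol=1e-4,
)
print("gradcheck passed:", ok)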
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,16 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
+                                        const at::Tensor &spatial_shapes,
+                                        const at::Tensor &level_start_index,
+                                        const at::Tensor &sampling_loc,
+                                        const at::Tensor &attn_weight,
+                                        const int64_t im2col_step);
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+     const at::Tensor &value, const at::Tensor &spatial_shapes,
+     const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
+     const at::Tensor &attn_weight, const at::Tensor &grad_output,
+     const int64_t im2col_step);