#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cstring>
#include <vector>
#ifdef WITH_CUDA
#include "../box_iou_rotated/box_iou_rotated_utils.h"
#endif
#ifdef WITH_HIP
#include "box_iou_rotated/box_iou_rotated_utils.h"
#endif
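
// Rotated NMS on the GPU. Each box is stored as 5 values:
// (x_center, y_center, width, height, angle_degrees). A kernel fills a
// pairwise suppression bitmask, which the host then scans greedily in
// descending score order to collect the indices of the kept boxes.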

using namespace detectron2;

namespace {
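// 64 threads per block: the overlap results of one box against one chunk of
// boxes pack exactly into a single unsigned long long bitmask word.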
int const threadsPerBlock = sizeof(unsigned long long) * 8;
} // namespace

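// Each CUDA block (blockIdx.y, blockIdx.x) compares one chunk of up to 64
// "row" boxes against one chunk of up to 64 "column" boxes and records, per
// row box, a 64-bit mask of the column boxes it overlaps beyond the IoU
// threshold.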
template <typename T>
__global__ void nms_rotated_cuda_kernel(
    const int n_boxes,
    const double iou_threshold,
    const T* dev_boxes,
    unsigned long long* dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  const int row_size =
      min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
      min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

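  // Stage the column chunk's boxes (5 values each) in shared memory so every
  // thread in the block can read them without re-touching global memory.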
  __shared__ T block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

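  // Each thread owns one row box and tests it against every staged column
  // box. On the diagonal block, start past threadIdx.x so a box is never
  // compared against itself or against earlier (higher-scored) boxes of the
  // same chunk.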
  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const T* cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5) >
          iou_threshold) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

namespace detectron2 {

at::Tensor nms_rotated_cuda(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold) {
  AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
  AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
  at::cuda::CUDAGuard device_guard(dets.device());

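  // Sort boxes by score, descending, so that a lower index always means a
  // higher score in both the kernel and the greedy scan below.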
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t);

  auto dets_num = dets.size(0);

  const int col_blocks =
      at::cuda::ATenCeilDiv(static_cast<int>(dets_num), threadsPerBlock);

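  // One 64-bit mask word per (box, column-chunk) pair; an int64 tensor is
  // reused as backing storage for the unsigned long long words.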
  at::Tensor mask =
      at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

  dim3 blocks(col_blocks, col_blocks);
  dim3 threads(threadsPerBlock);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(
      dets_sorted.scalar_type(), "nms_rotated_kernel_cuda", [&] {
        nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            dets_num,
            iou_threshold,
            dets_sorted.data_ptr<scalar_t>(),
            (unsigned long long*)mask.data_ptr<int64_t>());
      });
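
  // Copy the suppression bitmask back to the host for the sequential greedy
  // pass; remv accumulates the mask words of the boxes kept so far.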
  at::Tensor mask_cpu = mask.to(at::kCPU);
  unsigned long long* mask_host =
      (unsigned long long*)mask_cpu.data_ptr<int64_t>();

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  at::Tensor keep =
      at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  int num_to_keep = 0;
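  // Greedy scan in score order: keep box i unless an already-kept box
  // suppresses it; when kept, fold its mask row into remv.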
  for (int i = 0; i < dets_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      unsigned long long* p = mask_host + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

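  // Surface any kernel launch failure, then map the kept indices (which
  // refer to the sorted order) back to the caller's original box order
  // through order_t.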
  AT_CUDA_CHECK(cudaGetLastError());
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
           .to(order_t.device(), keep.scalar_type())});
}

} // namespace detectron2