|
|
|
#include "../box_iou_rotated/box_iou_rotated_utils.h"
|
|
#include "nms_rotated.h"
|
|
|
|
namespace detectron2 {
|
|
|
|
template <typename scalar_t>
|
|
at::Tensor nms_rotated_cpu_kernel(
|
|
const at::Tensor& dets,
|
|
const at::Tensor& scores,
|
|
const double iou_threshold) {
|
|
|
|
|
|
|
|
|
|
AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor");
|
|
AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor");
|
|
AT_ASSERTM(
|
|
dets.scalar_type() == scores.scalar_type(),
|
|
"dets should have the same type as scores");
|
|
|
|
if (dets.numel() == 0) {
|
|
return at::empty({0}, dets.options().dtype(at::kLong));
|
|
}
|
|
|
|
auto order_t = std::get<1>(scores.sort(0, true));
|
|
|
|
auto ndets = dets.size(0);
|
|
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
|
|
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
|
|
|
|
auto suppressed = suppressed_t.data_ptr<uint8_t>();
|
|
auto keep = keep_t.data_ptr<int64_t>();
|
|
auto order = order_t.data_ptr<int64_t>();
|
|
|
|
int64_t num_to_keep = 0;
|
|
|
|
for (int64_t _i = 0; _i < ndets; _i++) {
|
|
auto i = order[_i];
|
|
if (suppressed[i] == 1) {
|
|
continue;
|
|
}
|
|
|
|
keep[num_to_keep++] = i;
|
|
|
|
for (int64_t _j = _i + 1; _j < ndets; _j++) {
|
|
auto j = order[_j];
|
|
if (suppressed[j] == 1) {
|
|
continue;
|
|
}
|
|
|
|
auto ovr = single_box_iou_rotated<scalar_t>(
|
|
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());
|
|
if (ovr >= iou_threshold) {
|
|
suppressed[j] = 1;
|
|
}
|
|
}
|
|
}
|
|
return keep_t.narrow(0, 0, num_to_keep);
|
|
}
|
|
|
|
// CPU entry point for rotated-box NMS.
//
// Selects the nms_rotated_cpu_kernel instantiation matching the scalar type
// of dets (floating-point types only, via AT_DISPATCH_FLOATING_TYPES) and
// returns whatever that kernel returns.
//
// Args:
//   dets:          tensor of boxes; its scalar type drives the dispatch.
//   scores:        tensor of per-box scores, forwarded unchanged.
//   iou_threshold: overlap threshold, forwarded unchanged.
at::Tensor nms_rotated_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const double iou_threshold) {
  // Placeholder value; it is overwritten by the dispatched kernel below.
  at::Tensor kept_indices = at::empty({0}, dets.options());

  const auto dets_dtype = dets.scalar_type();
  AT_DISPATCH_FLOATING_TYPES(dets_dtype, "nms_rotated", [&] {
    kept_indices =
        nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
  });

  return kept_indices;
}
|
|
|
|
}
|
|
|