|
|
|
#pragma once |
|
#include <torch/types.h> |
|
|
|
namespace detectron2 { |
|
|
|
at::Tensor nms_rotated_cpu( |
|
const at::Tensor& dets, |
|
const at::Tensor& scores, |
|
const double iou_threshold); |
|
|
|
#if defined(WITH_CUDA) || defined(WITH_HIP) |
|
at::Tensor nms_rotated_cuda( |
|
const at::Tensor& dets, |
|
const at::Tensor& scores, |
|
const double iou_threshold); |
|
#endif |
|
|
|
|
|
|
|
|
|
inline at::Tensor nms_rotated( |
|
const at::Tensor& dets, |
|
const at::Tensor& scores, |
|
const double iou_threshold) { |
|
assert(dets.device().is_cuda() == scores.device().is_cuda()); |
|
if (dets.device().is_cuda()) { |
|
#if defined(WITH_CUDA) || defined(WITH_HIP) |
|
return nms_rotated_cuda( |
|
dets.contiguous(), scores.contiguous(), iou_threshold); |
|
#else |
|
AT_ERROR("Detectron2 is not compiled with GPU support!"); |
|
#endif |
|
} |
|
|
|
return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold); |
|
} |
|
|
|
} |
|
|