File size: 750 Bytes
cae2c48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
#pragma once
#include <torch/torch.h>
/// @brief Declares the forward entry point `ms_deform_attn_cuda_forward`.
///
/// Only the prototype lives in this header; the definition is elsewhere.
/// NOTE(review): judging by the name, this is presumably the CUDA forward pass
/// of a multi-scale deformable attention op -- confirm against the
/// implementation. Tensor shapes, dtypes and device requirements are not
/// visible from this declaration.
///
/// @param value             Input tensor (read-only reference).
/// @param spatial_shapes    Read-only tensor; presumably per-level spatial
///                          sizes -- TODO confirm.
/// @param level_start_index Read-only tensor; presumably the start offset of
///                          each level inside `value` -- TODO confirm.
/// @param sampling_loc      Read-only tensor; presumably sampling locations --
///                          TODO confirm.
/// @param attn_weight       Read-only tensor; presumably attention weights --
///                          TODO confirm.
/// @param im2col_step       64-bit integer step parameter; its exact meaning
///                          (likely a batch chunk size) must be checked in the
///                          definition.
/// @return A single `at::Tensor`.
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int64_t im2col_step);
/// @brief Declares the backward entry point `ms_deform_attn_cuda_backward`.
///
/// Only the prototype lives in this header; the definition is elsewhere.
/// The parameter list is `value`, `spatial_shapes`, `level_start_index`,
/// `sampling_loc`, `attn_weight` and `im2col_step`, plus `grad_output`.
/// NOTE(review): presumably the CUDA backward pass of a multi-scale deformable
/// attention op -- confirm against the implementation. Tensor shapes, dtypes
/// and device requirements are not visible from this declaration.
/// NOTE(review): `std::vector` is used here, but the visible part of this
/// header includes only <torch/torch.h>; presumably <vector> arrives
/// transitively -- consider including it directly.
///
/// @param value             Input tensor (read-only reference).
/// @param spatial_shapes    Read-only tensor; presumably per-level spatial
///                          sizes -- TODO confirm.
/// @param level_start_index Read-only tensor; presumably the start offset of
///                          each level inside `value` -- TODO confirm.
/// @param sampling_loc      Read-only tensor; presumably sampling locations --
///                          TODO confirm.
/// @param attn_weight       Read-only tensor; presumably attention weights --
///                          TODO confirm.
/// @param grad_output       Read-only tensor; presumably the gradient flowing
///                          in from the forward output -- TODO confirm.
/// @param im2col_step       64-bit integer step parameter; its exact meaning
///                          must be checked in the definition.
/// @return A vector of `at::Tensor`; the number and order of the elements are
///         not specified here (presumably gradients for some of the inputs --
///         verify in the definition).
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
const int64_t im2col_step);
|