|
from typing import List |
|
import torch |
|
|
|
from ._ops import ops |
|
from . import layers |
|
|
|
|
|
def ms_deform_attn_backward(
    value: torch.Tensor,
    spatial_shapes: torch.Tensor,
    level_start_index: torch.Tensor,
    sampling_loc: torch.Tensor,
    attn_weight: torch.Tensor,
    grad_output: torch.Tensor,
    im2col_step: int,
) -> List[torch.Tensor]:
    """Invoke the compiled multi-scale deformable attention backward op.

    Thin Python wrapper: the seven arguments are packed in the order they
    were received and handed positionally to ``ops.ms_deform_attn_backward``.
    Its result is returned unchanged. Nothing is validated or converted here.

    Args:
        value: Value tensor used in the forward pass.
            NOTE(review): layout presumably matches the forward op's
            ``value`` argument -- confirm against the kernel.
        spatial_shapes: Per-level spatial sizes (assumed; confirm).
        level_start_index: Per-level start offsets (assumed; confirm).
        sampling_loc: Sampling locations used in the forward pass.
        attn_weight: Attention weights used in the forward pass.
        grad_output: Upstream gradient of the forward output.
        im2col_step: Integer forwarded verbatim to the compiled op.

    Returns:
        The list of tensors produced by the compiled op. These are
        presumably gradients w.r.t. ``value``, ``sampling_loc`` and
        ``attn_weight`` -- TODO confirm ordering against the kernel.
    """
    # Positional order must match the registered op schema exactly.
    op_args = (
        value,
        spatial_shapes,
        level_start_index,
        sampling_loc,
        attn_weight,
        grad_output,
        im2col_step,
    )
    return ops.ms_deform_attn_backward(*op_args)
|
|
|
|
|
def ms_deform_attn_forward(
    value: torch.Tensor,
    spatial_shapes: torch.Tensor,
    level_start_index: torch.Tensor,
    sampling_loc: torch.Tensor,
    attn_weight: torch.Tensor,
    im2col_step: int,
) -> torch.Tensor:
    """Invoke the compiled multi-scale deformable attention forward op.

    Thin Python wrapper: every argument is passed positionally, in the
    order received, to ``ops.ms_deform_attn_forward``. Its result is
    returned unchanged. Nothing is validated or converted here.

    Args:
        value: Tensor of values to sample from.
            NOTE(review): presumably (batch, num_keys, num_heads, head_dim)
            -- confirm against the kernel.
        spatial_shapes: Per-level spatial sizes (presumably (num_levels, 2)
            height/width pairs -- confirm).
        level_start_index: Per-level start offsets into ``value``
            (assumed; confirm).
        sampling_loc: Sampling locations.
        attn_weight: Attention weights applied to the sampled values.
        im2col_step: Integer forwarded verbatim to the compiled op.

    Returns:
        The tensor produced by the compiled forward op.
    """
    kernel = ops.ms_deform_attn_forward
    # Positional order must match the registered op schema exactly.
    attended = kernel(
        value, spatial_shapes, level_start_index,
        sampling_loc, attn_weight, im2col_step,
    )
    return attended
|
|
|
|
|
# Public API of this package: the ``layers`` submodule plus the two raw op
# wrappers defined above.
__all__ = [
    "layers",
    "ms_deform_attn_forward",
    "ms_deform_attn_backward",
]
|
|