# Provenance (copied from the hosting page): author danieldk (HF staff),
# revision 98affba — "Expose layer".
from typing import List
import torch
from ._ops import ops
from . import layers
def ms_deform_attn_backward(
    value: torch.Tensor,
    spatial_shapes: torch.Tensor,
    level_start_index: torch.Tensor,
    sampling_loc: torch.Tensor,
    attn_weight: torch.Tensor,
    grad_output: torch.Tensor,
    im2col_step: int,
) -> List[torch.Tensor]:
    """Run the backward pass of the multi-scale deformable attention kernel.

    This is a thin Python wrapper. Every argument is forwarded unchanged and
    in the same positional order to the compiled
    ``ops.ms_deform_attn_backward`` operator. No validation or conversion
    happens here.

    Args:
        value: Input tensor that was fed to the forward pass.
        spatial_shapes: Presumably the per-level spatial sizes; the exact
            layout is defined by the kernel (TODO confirm).
        level_start_index: Presumably the start offset of each level within
            the flattened ``value``; verify against the kernel.
        sampling_loc: Sampling locations used in the forward pass.
        attn_weight: Attention weights used in the forward pass.
        grad_output: Upstream gradient of the forward output.
        im2col_step: Integer passed straight through to the kernel. Its
            meaning is defined by the kernel (NOTE(review): confirm).

    Returns:
        The list of tensors produced by the kernel. These are presumably the
        gradients for the differentiable inputs; verify against the kernel.
    """
    # Pack the arguments once so the call site mirrors the kernel's schema order.
    kernel_args = (
        value,
        spatial_shapes,
        level_start_index,
        sampling_loc,
        attn_weight,
        grad_output,
        im2col_step,
    )
    return ops.ms_deform_attn_backward(*kernel_args)
def ms_deform_attn_forward(
    value: torch.Tensor,
    spatial_shapes: torch.Tensor,
    level_start_index: torch.Tensor,
    sampling_loc: torch.Tensor,
    attn_weight: torch.Tensor,
    im2col_step: int,
) -> torch.Tensor:
    """Run the forward pass of the multi-scale deformable attention kernel.

    This is a thin Python wrapper. Every argument is forwarded unchanged and
    in the same positional order to the compiled
    ``ops.ms_deform_attn_forward`` operator. No validation or conversion
    happens here.

    Args:
        value: Input tensor to attend over.
        spatial_shapes: Presumably the per-level spatial sizes; the exact
            layout is defined by the kernel (TODO confirm).
        level_start_index: Presumably the start offset of each level within
            the flattened ``value``; verify against the kernel.
        sampling_loc: Sampling locations consumed by the kernel.
        attn_weight: Attention weights consumed by the kernel.
        im2col_step: Integer passed straight through to the kernel. Its
            meaning is defined by the kernel (NOTE(review): confirm).

    Returns:
        The single tensor produced by the kernel.
    """
    kernel = ops.ms_deform_attn_forward
    output = kernel(
        value,
        spatial_shapes,
        level_start_index,
        sampling_loc,
        attn_weight,
        im2col_step,
    )
    return output
# Public API of this package: the `layers` submodule plus the two raw
# kernel wrappers defined above.
__all__ = [
    "layers",
    "ms_deform_attn_forward",
    "ms_deform_attn_backward",
]