|
from typing import List, Union, Tuple |
|
|
|
from torch import Tensor |
|
from torch.autograd import Function |
|
from torch.autograd.function import once_differentiable |
|
import torch.nn as nn |
|
|
|
from ._ops import ops |
|
|
|
|
|
class MultiScaleDeformableAttentionFunction(Function):
    """Autograd bridge for the fused multi-scale deformable attention kernel.

    Forward and backward both dispatch to the compiled ``ops`` extension;
    this class only wires tensors and the ``im2col_step`` batching knob
    through PyTorch's autograd machinery.
    """

    @staticmethod
    def forward(
        context,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        """Run the fused forward kernel and stash inputs for backward."""
        # Remember the batch chunking step so backward replays the same split.
        context.im2col_step = im2col_step
        # All tensor inputs are needed again to compute gradients.
        context.save_for_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        return ops.ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step,
        )

    @staticmethod
    @once_differentiable
    def backward(context, grad_output):
        """Compute grads w.r.t. value, sampling locations and attention weights."""
        saved = context.saved_tensors  # same five tensors, in forward order
        grad_value, grad_sampling_loc, grad_attn_weight = ops.ms_deform_attn_backward(
            *saved,
            grad_output,
            context.im2col_step,
        )
        # Spatial shapes, level start indices and im2col_step take no gradient.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
|
|
|
|
class MultiScaleDeformableAttention(nn.Module):
    """Thin ``nn.Module`` front-end for the deformable attention autograd op."""

    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        """Delegate straight to ``MultiScaleDeformableAttentionFunction``.

        NOTE(review): ``value_spatial_shapes_list`` is accepted but never
        forwarded — presumably kept for signature compatibility with a
        pure-PyTorch fallback; confirm against callers before removing.
        """
        kernel_args = (
            value,
            value_spatial_shapes,
            level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step,
        )
        return MultiScaleDeformableAttentionFunction.apply(*kernel_args)
|
|
|
|
|
# Public API: only the nn.Module wrapper; the raw autograd Function stays internal.
__all__ = ["MultiScaleDeformableAttention"]
|
|