from typing import List, Tuple
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.nn as nn
from ._ops import ops


class MultiScaleDeformableAttentionFunction(Function):
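    """Autograd bridge to the fused multi-scale deformable attention kernels.

    ``forward`` dispatches to ``ops.ms_deform_attn_forward`` and saves the
    tensor inputs; ``backward`` feeds them, together with the incoming
    gradient, to ``ops.ms_deform_attn_backward``.
    """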
@staticmethod
def forward(
context,
value: Tensor,
value_spatial_shapes: Tensor,
value_level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
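        # Keep im2col_step on the autograd context so backward can reuse it.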
context.im2col_step = im2col_step
output = ops.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
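        # Save every tensor argument; the backward kernel needs all of them.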
context.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
return output

    @staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
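        # Compute input gradients with the fused backward kernel.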
grad_value, grad_sampling_loc, grad_attn_weight = ops.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
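        # One gradient per forward input; the spatial shapes, level start
        # indices and im2col_step are non-differentiable, hence the Nones.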
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


class MultiScaleDeformableAttention(nn.Module):
def forward(
self,
value: Tensor,
value_spatial_shapes: Tensor,
value_spatial_shapes_list: List[Tuple],
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
    ) -> Tensor:
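        # ``value_spatial_shapes_list`` is unused on this fused-kernel path
        # (the kernel reads the shapes from the ``value_spatial_shapes``
        # tensor); presumably it is kept for signature parity with
        # pure-PyTorch fallback implementations.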
return MultiScaleDeformableAttentionFunction.apply(
value,
value_spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
im2col_step,
)


__all__ = ["MultiScaleDeformableAttention"]
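

if __name__ == "__main__":
    # Minimal smoke test, not part of the public API: a sketch assuming the
    # Deformable-DETR shape conventions and a CUDA build of the fused ops.
    # All sizes below are illustrative. Run via ``python -m`` so the relative
    # ``._ops`` import resolves.
    import torch

    if torch.cuda.is_available():
        bs, num_heads, head_dim = 2, 8, 32
        num_queries, num_points = 100, 4
        spatial_shapes_list = [(32, 32), (16, 16)]
        num_levels = len(spatial_shapes_list)

        spatial_shapes = torch.tensor(
            spatial_shapes_list, dtype=torch.long, device="cuda"
        )
        # Flattened offset of each level inside the concatenated key sequence.
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        num_keys = int(spatial_shapes.prod(1).sum())

        value = torch.rand(bs, num_keys, num_heads, head_dim, device="cuda")
        # Normalized (x, y) sampling points per query/head/level/point.
        sampling_locations = torch.rand(
            bs, num_queries, num_heads, num_levels, num_points, 2, device="cuda"
        )
        attention_weights = torch.rand(
            bs, num_queries, num_heads, num_levels, num_points, device="cuda"
        )
        # Weights are expected to be normalized over the levels x points axes.
        attention_weights = attention_weights / attention_weights.sum(
            dim=(-2, -1), keepdim=True
        )

        attn = MultiScaleDeformableAttention()
        output = attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            im2col_step=64,
        )
        print(output.shape)  # expected (bs, num_queries, num_heads * head_dim)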