# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import ModuleList
from torch import Tensor, nn

from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
from .utils import inverse_sigmoid
class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
    """Transformer encoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize encoder layers."""
        self.layers = ModuleList([
            DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims

    def forward(self, query: Tensor, query_pos: Tensor,
                key_padding_mask: Tensor, spatial_shapes: Tensor,
                level_start_index: Tensor, valid_ratios: Tensor,
                **kwargs) -> Tensor:
        """Forward function of Transformer encoder.

        Args:
            query (Tensor): The input query, has shape (bs, num_queries, dim).
            query_pos (Tensor): The positional encoding for query, has shape
                (bs, num_queries, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
                input. ByteTensor, has shape (bs, num_queries).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            Tensor: Output queries of Transformer encoder, which is also
            called 'encoder output embeddings' or 'memory', has shape
            (bs, num_queries, dim)
        """
        reference_points = self.get_encoder_reference_points(
            spatial_shapes, valid_ratios, device=query.device)
        for layer in self.layers:
            query = layer(
                query=query,
                query_pos=query_pos,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points,
                **kwargs)
        return query

    @staticmethod
    def get_encoder_reference_points(
            spatial_shapes: Tensor, valid_ratios: Tensor,
            device: Union[torch.device, str]) -> Tensor:
        """Get the reference points used in encoder.

        NOTE: must be a ``@staticmethod``: the first positional parameter is
        ``spatial_shapes``, not ``self``, yet ``forward`` invokes it as
        ``self.get_encoder_reference_points(...)``. Without the decorator the
        bound instance would be passed as ``spatial_shapes`` and every
        argument would be shifted by one.

        Args:
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            device (obj:`device` or str): The device acquired by the
                `reference_points`.

        Returns:
            Tensor: Reference points used in decoder, has shape (bs, length,
            num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # Sample one point at the center of every feature-map cell,
            # normalized by the *valid* (unpadded) extent of this level.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # [bs, sum(hw), num_level, 2]
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
    """Transformer Decoder of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize decoder layers."""
        layers = [
            DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ]
        self.layers = ModuleList(layers)
        self.embed_dims = layers[0].embed_dims
        # Deformable DETR's decoder applies no final norm; reject the config
        # instead of silently ignoring it.
        if self.post_norm_cfg is not None:
            raise ValueError('There is not post_norm in '
                             f'{self._get_name()}')

    def forward(self,
                query: Tensor,
                query_pos: Tensor,
                value: Tensor,
                key_padding_mask: Tensor,
                reference_points: Tensor,
                spatial_shapes: Tensor,
                level_start_index: Tensor,
                valid_ratios: Tensor,
                reg_branches: Optional[nn.Module] = None,
                **kwargs) -> Tuple[Tensor]:
        """Forward function of Transformer decoder.

        Args:
            query (Tensor): The input queries, has shape (bs, num_queries,
                dim).
            query_pos (Tensor): The input positional query, has shape
                (bs, num_queries, dim). It will be added to `query` before
                forward function.
            value (Tensor): The input values, has shape (bs, num_value, dim).
            key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
                input. ByteTensor, has shape (bs, num_value).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries, 4) with the last dimension arranged as
                (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
                shape (bs, num_queries, 2) with the last dimension arranged
                as (cx, cy).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
                the regression results. Only would be passed when
                `with_box_refine` is `True`, otherwise would be `None`.

        Returns:
            tuple[Tensor]: Outputs of Deformable Transformer Decoder.

            - output (Tensor): Output embeddings of the last decoder, has
              shape (num_queries, bs, embed_dims) when `return_intermediate`
              is `False`. Otherwise, Intermediate output embeddings of all
              decoder layers, has shape (num_decoder_layers, num_queries, bs,
              embed_dims).
            - reference_points (Tensor): The reference of the last decoder
              layer, has shape (bs, num_queries, 4) when `return_intermediate`
              is `False`. Otherwise, Intermediate references of all decoder
              layers, has shape (num_decoder_layers, bs, num_queries, 4). The
              coordinates are arranged as (cx, cy, w, h)
        """
        hidden = query
        collected_outputs = []
        collected_refs = []
        for idx, decoder_layer in enumerate(self.layers):
            # Rescale the references by the per-level valid ratios so the
            # deformable attention samples inside the unpadded region.
            last_dim = reference_points.shape[-1]
            if last_dim == 4:
                scale = torch.cat([valid_ratios, valid_ratios], -1)
            else:
                assert last_dim == 2
                scale = valid_ratios
            reference_points_input = \
                reference_points[:, :, None] * scale[:, None]

            hidden = decoder_layer(
                hidden,
                query_pos=query_pos,
                value=value,
                key_padding_mask=key_padding_mask,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                reference_points=reference_points_input,
                **kwargs)

            if reg_branches is not None:
                # Iterative bounding-box refinement (`with_box_refine`):
                # apply this layer's regression head in inverse-sigmoid space.
                delta = reg_branches[idx](hidden)
                if reference_points.shape[-1] == 4:
                    refined = delta + inverse_sigmoid(reference_points)
                    refined = refined.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    refined = delta
                    refined[..., :2] = delta[..., :2] + inverse_sigmoid(
                        reference_points)
                    refined = refined.sigmoid()
                # Detach so gradients do not flow through the refinement
                # chain across layers.
                reference_points = refined.detach()

            if self.return_intermediate:
                collected_outputs.append(hidden)
                collected_refs.append(reference_points)

        if self.return_intermediate:
            return torch.stack(collected_outputs), torch.stack(collected_refs)

        return hidden, reference_points
class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
    """Encoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn, and norms."""
        # Multi-scale deformable attention replaces vanilla self-attention.
        self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # One norm after self-attention, one after the FFN.
        self.norms = ModuleList(
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(2))
class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
    """Decoder layer of Deformable DETR."""

    def _init_layers(self) -> None:
        """Initialize self_attn, cross-attn, ffn, and norms."""
        # Plain multi-head self-attention over the object queries,
        # multi-scale deformable attention for query-to-memory cross-attn.
        self.self_attn = MultiheadAttention(**self.self_attn_cfg)
        self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
        self.embed_dims = self.self_attn.embed_dims
        self.ffn = FFN(**self.ffn_cfg)
        # One norm after each of: self-attn, cross-attn, FFN.
        self.norms = ModuleList(
            build_norm_layer(self.norm_cfg, self.embed_dims)[1]
            for _ in range(3))