from functools import partial
import math
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

from detectron2.config import LazyCall as L
from detectron2.layers import Conv2d, get_norm
from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.backbone.fpn import (
    LastLevelMaxPool,
    _assert_strides_are_log2_contiguous,
)

sys.path.append('../../')
from modeling_pretrain_cleaned import PretrainVisionTransformer
from models.mask_rcnn_fpn_v2 import model, constants

# Pixel statistics from the base config (the standard ImageNet per-channel
# mean/std in 0-255 RGB order).
model.pixel_mean = constants['imagenet_rgb256_mean']
model.pixel_std = constants['imagenet_rgb256_std']
model.input_format = "RGB"


class ViT(Backbone):
    """Detectron2 Backbone wrapper around the pretrained video transformer."""

    def __init__(
        self,
        img_size=224,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(8, 8),
        num_frames=3,
        tubelet_size=1,
        use_flash_attention=True,
        return_detectron_format=True,
        out_feature='last_feat',
    ):
        super().__init__()
        self.model = PretrainVisionTransformer(
            img_size=img_size,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            encoder_num_classes=encoder_num_classes,
            decoder_embed_dim=decoder_embed_dim,
            decoder_num_heads=decoder_num_heads,
            decoder_depth=decoder_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            k_bias=k_bias,
            norm_layer=norm_layer,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            use_flash_attention=use_flash_attention,
            return_detectron_format=return_detectron_format,
            out_feature=out_feature,
        )
        self._out_features = [out_feature]
        # The num_frames - 1 forwarded frames (2 here) are stacked channel-wise,
        # hence twice the encoder embedding dim.
        self._out_feature_channels = {out_feature: encoder_embed_dim * 2}
        self._out_feature_strides = {out_feature: patch_size[0]}
        # Token grid for the 512x512 padded input (must match square_pad below).
        self.patch_hw = 512 // patch_size[0]
        self.num_frames = num_frames
        # Resize the pretrained positional embeddings to the detection token grid,
        # then keep positions for the first num_frames - 1 frames only.
        pos_embed = self.get_abs_pos(
            self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw]
        )
        self.model.encoder.pos_embed = pos_embed[:, 0:self.patch_hw ** 2 * (self.num_frames - 1), :]

    def forward(self, x):
        B = x.shape[0]
        # Replicate the still image across num_frames - 1 temporal positions so
        # the video encoder sees a constant clip: (B, C, H, W) -> (B, C, T, H, W).
        x = x.unsqueeze(2).expand(-1, -1, self.num_frames - 1, -1, -1)
        # All-False mask: every patch token is visible to the encoder.
        mask = torch.zeros(
            B, self.patch_hw ** 2 * (self.num_frames - 1), dtype=torch.bool, device=x.device
        )
        return self.model(x, mask)
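
    # Shape sketch for forward() under the settings at the bottom of this file
    # (patch_size=(16, 16), num_frames=3, square_pad=512); the output entry
    # assumes return_detectron_format=True makes PretrainVisionTransformer
    # return a {name: tensor} dict:
    #   x in:      (B, 3, 512, 512)
    #   x expand:  (B, 3, 2, 512, 512)               image repeated over 2 frames
    #   mask:      (B, 2 * 32 * 32) = (B, 2048)      all False
    #   out:       {'last_feat': (B, 1536, 32, 32)}  2 frames x 768 channels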

    def get_abs_pos(self, abs_pos, num_frames, hw):
        """
        Resize absolute positional embeddings to a target token grid, frame by frame.

        Args:
            abs_pos (Tensor): absolute positional embeddings with shape
                (1, num_frames * size * size, C).
            num_frames (int): number of frames covered by the embeddings.
            hw (Tuple): (height, width) of the target grid of image tokens.

        Returns:
            Absolute positional embeddings with shape (1, num_frames * h * w, C).
        """
        h, w = hw

        xy_num = abs_pos.shape[1] // num_frames
        size = int(math.sqrt(xy_num))
        assert size * size * num_frames == abs_pos.shape[1]

        if size != h or size != w:
            # Bicubically resize each frame's (size, size) grid to (h, w).
            new_abs_pos = F.interpolate(
                abs_pos.view(num_frames, size, size, -1).permute(0, 3, 1, 2),
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
            # (num_frames, C, h, w) -> (1, num_frames * h * w, C)
            return new_abs_pos.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
        else:
            # Already the right resolution; return unchanged, preserving the
            # (1, num_frames * h * w, C) layout the caller expects.
            return abs_pos
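
    # Worked example under this file's assumptions (encoder pretrained at
    # img_size=224 with 16x16 patches, so size = 224 / 16 = 14): each frame's
    # 14x14 grid is bicubically upsampled to the 32x32 grid implied by the
    # 512-pixel square padding (512 / 16 = 32).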


class SimpleFeaturePyramid(BaseSimpleFeaturePyramid):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`, extended
    with a 0.25x scale so a stride-64 pyramid level can be built directly from
    the backbone feature map.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
""" |
|
Args: |
|
net (Backbone): module representing the subnetwork backbone. |
|
Must be a subclass of :class:`Backbone`. |
|
in_feature (str): names of the input feature maps coming |
|
from the net. |
|
out_channels (int): number of channels in the output feature maps. |
|
scale_factors (list[float]): list of scaling factors to upsample or downsample |
|
the input features for creating pyramid features. |
|
top_block (nn.Module or None): if provided, an extra operation will |
|
be performed on the output of the last (smallest resolution) |
|
pyramid output, and the result will extend the result list. The top_block |
|
further downsamples the feature map. It must have an attribute |
|
"num_levels", meaning the number of extra pyramid levels added by |
|
this block, and "in_feature", which is a string representing |
|
its input feature (e.g., p5). |
|
norm (str): the normalization to use. |
|
square_pad (int): If > 0, require input images to be padded to specific square size. |
|
""" |
|
        # Call Backbone.__init__ directly, deliberately skipping
        # BaseSimpleFeaturePyramid.__init__, which this method replaces wholesale.
        super(BaseSimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors

        input_shapes = net.output_shape()
        strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        dim = input_shapes[in_feature].channels
        self.stages = []
        use_bias = norm == ""
        for idx, scale in enumerate(scale_factors):
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    get_norm(norm, dim // 2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif scale == 0.25:
                # The 0.25x branch is the addition over the stock ViTDet
                # SimpleFeaturePyramid; it produces a stride-64 level.
                layers = [nn.MaxPool2d(kernel_size=4, stride=4)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            layers.extend(
                [
                    Conv2d(
                        out_dim,
                        out_channels,
                        kernel_size=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                    Conv2d(
                        out_channels,
                        out_channels,
                        kernel_size=3,
                        padding=1,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                    ),
                ]
            )
            layers = nn.Sequential(*layers)

            stage = int(math.log2(strides[idx]))
            self.add_module(f"simfp_{stage}", layers)
            self.stages.append(layers)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block

        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}

        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad


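# With the settings below (a stride-16 backbone and scale_factors from 4.0 down
# to 0.25), the pyramid strides work out to 4, 8, 16, 32, 64 (p2-p6), and
# LastLevelMaxPool stacks a stride-128 p7 on top.
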
# ViT-B/16 settings; dp (drop path rate) is kept for parity with the ViTDet
# configs but is not consumed by the wrapper above.
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1

model.backbone = L(SimpleFeaturePyramid)(
    net=L(ViT)(
        img_size=224,
        encoder_embed_dim=embed_dim,
        encoder_depth=depth,
        encoder_num_heads=num_heads,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(16, 16),
        num_frames=3,
        tubelet_size=1,
        use_flash_attention=True,
        return_detectron_format=True,
        out_feature='last_feat',
    ),
    in_feature="${.net.out_feature}",
    out_channels=256,
    scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25),
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    square_pad=512,
)
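
# Hedged sanity check for the backbone wiring above; kept commented out because
# LazyConfig files are imported at load time, and it assumes
# PretrainVisionTransformer can be instantiated without a checkpoint:
#
#   from detectron2.config import instantiate
#   backbone = instantiate(model.backbone)
#   feats = backbone(torch.zeros(1, 3, 512, 512))
#   for name, feat in feats.items():
#       print(name, tuple(feat.shape))   # p2-p7, 256 channels each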


model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"

# Two RPN convs; -1 keeps the incoming channel count (detectron2 semantics).
model.proposal_generator.head.conv_dims = [-1, -1]

# "4conv1fc" box head, as in the ViTDet baselines.
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
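
# A typical (hypothetical) way to consume this LazyConfig, assuming detectron2's
# stock lazy-config trainer is available in this setup:
#   python tools/lazyconfig_train_net.py --config-file path/to/this_config.py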