# Source snapshot: revision 6dfcb0f (10,038 bytes).
from functools import partial
import torch.nn as nn
from detectron2.config import LazyCall as L
from detectron2.modeling import ViT
from detectron2.modeling import SimpleFeaturePyramid as BaseSimpleFeaturePyramid
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
from detectron2.layers import CNNBlockBase, Conv2d, get_norm
import sys
sys.path.append('../../')
from modeling_pretrain_cleaned import PretrainVisionTransformer
from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
from models.mask_rcnn_fpn_v2 import model, constants
from detectron2.modeling.backbone import Backbone
import torch
import math
import torch.nn.functional as F
import time
# Input normalization: ImageNet channel statistics from the shared `constants`
# table (the "rgb256" keys presumably denote the 0-255 pixel scale -- confirm),
# with images declared to arrive in RGB channel order.
model.pixel_mean, model.pixel_std = (
    constants['imagenet_rgb256_mean'],
    constants['imagenet_rgb256_std'],
)
model.input_format = "RGB"
class ViT(Backbone):
    """
    Detectron2 ``Backbone`` that runs a ``PretrainVisionTransformer`` on still images.

    Each input image is repeated ``num_frames - 1`` times along a new time axis
    (dim 2) and passed to the wrapped transformer together with an all-False
    boolean token mask. At construction time the encoder's absolute positional
    embeddings are resized to an ``input_size // patch_size[0]`` square token grid
    and truncated to the first ``num_frames - 1`` frames, matching what
    ``forward`` feeds in. The backbone exposes a single output named ``out_feature``.

    Note: this class shadows the ``ViT`` imported from ``detectron2.modeling``.
    """

    def __init__(
        self,
        img_size=224,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(8, 8),
        num_frames=3,
        tubelet_size=1,
        use_flash_attention=True,
        return_detectron_format=True,
        out_feature='last_feat',
        input_size=512,
    ):
        """
        Args:
            img_size, encoder_*, decoder_*, mlp_ratio, qkv_bias, k_bias, norm_layer,
            tubelet_size, use_flash_attention, return_detectron_format:
                forwarded unchanged to ``PretrainVisionTransformer``.
            patch_size (tuple[int, int]): transformer patch size; only
                ``patch_size[0]`` is used here (as output stride and to size the
                token grid), so square patches are assumed.
            num_frames (int): number of frames the encoder's positional embeddings
                cover; ``forward`` feeds ``num_frames - 1`` copies of each image.
            out_feature (str): name of the single output feature map.
            input_size (int): side length in pixels of the square images this
                backbone receives; sets the positional-embedding grid and mask
                length. Defaults to 512, the previously hard-coded value.
        """
        super().__init__()
        self.model = PretrainVisionTransformer(  # Single-scale ViT backbone
            img_size=img_size,
            encoder_embed_dim=encoder_embed_dim,
            encoder_depth=encoder_depth,
            encoder_num_heads=encoder_num_heads,
            encoder_num_classes=encoder_num_classes,
            decoder_embed_dim=decoder_embed_dim,
            decoder_num_heads=decoder_num_heads,
            decoder_depth=decoder_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            k_bias=k_bias,
            norm_layer=norm_layer,
            patch_size=patch_size,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            use_flash_attention=use_flash_attention,
            return_detectron_format=return_detectron_format,
            out_feature=out_feature,
        )
        self._out_features = [out_feature]
        # NOTE(review): reported channels are 2 * encoder_embed_dim, presumably because
        # the wrapped model's output combines two encoder-width features -- confirm.
        self._out_feature_channels = {out_feature: encoder_embed_dim * 2}
        self._out_feature_strides = {out_feature: patch_size[0]}
        self.patch_hw = input_size // patch_size[0]
        self.num_frames = num_frames
        # NOTE(review): num_frames is passed as the number of temporal positions in
        # pos_embed, which presumably only holds for tubelet_size == 1 -- confirm.
        pos_embed = self.get_abs_pos(
            self.model.encoder.pos_embed, num_frames, [self.patch_hw, self.patch_hw]
        )
        # Keep only the first num_frames - 1 frames: forward() feeds that many.
        # NOTE(review): assigning a plain Tensor here assumes encoder.pos_embed is not
        # an nn.Parameter (nn.Module would reject that assignment) -- confirm.
        self.model.encoder.pos_embed = pos_embed[:, : self.patch_hw ** 2 * (self.num_frames - 1), :]

    def forward(self, x):
        """
        Run the wrapped transformer on a batch of images.

        Args:
            x (Tensor): 4-D image batch, presumably (B, C, H, W) with
                H == W == input_size -- TODO confirm.

        Returns:
            Whatever ``PretrainVisionTransformer`` returns for ``(video, mask)``.
        """
        clip_frames = self.num_frames - 1
        # Repeat each image along a new time axis (expand() is a view, no copy).
        video = x.unsqueeze(2).expand(-1, -1, clip_frames, -1, -1)
        # One bool per token, all False (presumably "nothing masked"); allocated
        # directly on the input's device instead of on CPU followed by a copy.
        mask = torch.zeros(
            x.shape[0],
            self.patch_hw ** 2 * clip_frames,
            dtype=torch.bool,
            device=x.device,
        )
        return self.model(video, mask)

    def get_abs_pos(self, abs_pos, num_frames, hw):
        """
        Resize per-frame absolute positional embeddings to a new token grid.

        Args:
            abs_pos (Tensor): embeddings of shape (1, num_frames * size * size, C),
                where the tokens of each frame form a square ``size x size`` grid.
            num_frames (int): number of frames (temporal positions) in ``abs_pos``.
            hw (Sequence[int]): target token grid ``(h, w)`` for every frame.

        Returns:
            Tensor of shape (1, num_frames * h * w, C). Each frame's grid is
            bicubically interpolated when it differs from ``(h, w)``; otherwise the
            values are returned unchanged in the same layout.
        """
        h, w = hw
        xy_num = abs_pos.shape[1] // num_frames
        size = int(math.sqrt(xy_num))
        assert size * size * num_frames == abs_pos.shape[1], (
            f"expected {num_frames} square frames of positional embeddings, "
            f"got {abs_pos.shape[1]} tokens"
        )
        # (num_frames, size, size, C): one spatial grid per frame.
        grid = abs_pos.reshape(num_frames, size, size, -1)
        if size != h or size != w:
            # interpolate works channels-first: (N, C, H, W).
            grid = F.interpolate(
                grid.permute(0, 3, 1, 2),
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            ).permute(0, 2, 3, 1)
        # Always return the (1, tokens, C) layout so callers can slice tokens along
        # dim 1 in both branches.
        return grid.flatten(0, 2).unsqueeze(0)
class SimpleFeaturePyramid(BaseSimpleFeaturePyramid):
    """
    Simple feature pyramid from :paper:`vitdet`, extended with a 0.25 scale factor.

    Every entry of ``scale_factors`` yields one pyramid level: the single input
    feature map is resampled by that factor and then projected to
    ``out_channels`` by a 1x1 and a 3x3 convolution. An optional ``top_block``
    appends further levels after the last one.
    """

    def __init__(
        self,
        net,
        in_feature,
        out_channels,
        scale_factors,
        top_block=None,
        norm="LN",
        square_pad=0,
    ):
        """
        Args:
            net (Backbone): backbone sub-network; must be a :class:`Backbone` instance.
            in_feature (str): name of the ``net`` output the pyramid is built on.
            out_channels (int): number of channels of every pyramid output.
            scale_factors (list[float]): resampling factor per pyramid level, each
                one of 4.0, 2.0, 1.0, 0.5 or 0.25.
            top_block (nn.Module or None): optional extra block applied after the
                last (lowest-resolution) level. It must expose ``num_levels`` (how
                many levels it adds) and ``in_feature`` (the level it reads, e.g. p5).
            norm (str): normalization type for ``get_norm``; an empty string
                disables normalization and enables the convolution biases instead.
            square_pad (int): if > 0, input images must be padded to this square size.

        Raises:
            NotImplementedError: for a scale factor outside the supported set.
        """
        # Skip BaseSimpleFeaturePyramid.__init__ and initialize the next class in
        # the MRO directly; every stage is built below instead (presumably because
        # the base constructor does not accept the 0.25 factor).
        super(BaseSimpleFeaturePyramid, self).__init__()
        assert isinstance(net, Backbone)

        self.scale_factors = scale_factors
        in_spec = net.output_shape()[in_feature]
        strides = [int(in_spec.stride / factor) for factor in scale_factors]
        _assert_strides_are_log2_contiguous(strides)

        in_dim = in_spec.channels
        self.stages = []
        conv_bias = norm == ""
        for factor, stride in zip(scale_factors, strides):
            resample, resampled_dim = self._resample_layers(in_dim, factor, norm)
            # Project every level to out_channels: 1x1 conv followed by 3x3 conv.
            resample += [
                Conv2d(
                    resampled_dim,
                    out_channels,
                    kernel_size=1,
                    bias=conv_bias,
                    norm=get_norm(norm, out_channels),
                ),
                Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    bias=conv_bias,
                    norm=get_norm(norm, out_channels),
                ),
            ]
            stage_module = nn.Sequential(*resample)
            level = int(math.log2(stride))
            self.add_module(f"simfp_{level}", stage_module)
            self.stages.append(stage_module)

        self.net = net
        self.in_feature = in_feature
        self.top_block = top_block
        # Output names are "p<log2(stride)>", e.g. ["p2", "p3", ..., "p6"].
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        if self.top_block is not None:
            # Each extra top-block level is declared with double the previous stride.
            for s in range(level, level + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)
        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = dict.fromkeys(self._out_features, out_channels)
        self._size_divisibility = strides[-1]
        self._square_pad = square_pad

    @staticmethod
    def _resample_layers(dim, scale, norm):
        """Return ``(layers, out_dim)`` that resample a ``dim``-channel map by ``scale``."""
        if scale == 4.0:
            return [
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                get_norm(norm, dim // 2),
                nn.GELU(),
                nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
            ], dim // 4
        if scale == 2.0:
            return [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)], dim // 2
        if scale == 1.0:
            return [], dim
        if scale == 0.5:
            return [nn.MaxPool2d(kernel_size=2, stride=2)], dim
        if scale == 0.25:
            return [nn.MaxPool2d(kernel_size=4, stride=4)], dim
        raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
# ViT-Base sized encoder: 768-dim embeddings, 12 layers, 12 attention heads.
embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1
# NOTE(review): `dp` (presumably a drop-path rate) is not forwarded to the
# backbone below -- confirm whether it should be.

# Creates Simple Feature Pyramid from ViT backbone
model.backbone = L(SimpleFeaturePyramid)(
    net=L(ViT)(
        img_size=224,
        encoder_embed_dim=embed_dim,
        encoder_depth=depth,
        encoder_num_heads=num_heads,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        patch_size=(16, 16),
        num_frames=3,
        tubelet_size=1,
        return_detectron_format=True,
        use_flash_attention=True,
        out_feature='last_feat',
    ),
    in_feature="${.net.out_feature}",
    out_channels=256,
    scale_factors=(4.0, 2.0, 1.0, 0.5, 0.25),
    # NOTE(review): the 0.25 factor already makes the last pyramid level p6, so this
    # top block adds a "p7" declared at stride 128 while reading its own in_feature
    # (presumably p5) -- confirm the extra level and its stride are intended.
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    square_pad=512,
)

model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"

# 2conv in RPN:
model.proposal_generator.head.conv_dims = [-1, -1]

# 4conv1fc box head
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]