# This source code is based on https://github.com/facebookresearch/MCC
# The original codebase is licensed under the license found in the LICENSE file in the root directory.
import torch
import torch.nn as nn
import torchvision
from functools import partial
from timm.models.vision_transformer import Block
from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv
class CoordEmb(nn.Module):
"""
Encode the seen coordinate map to a lower resolution feature map
Achieved with window-wise attention block by deviding coord map into windows
Each window is seperately encoded into a single CLS token with self-attention and posenc
"""
def __init__(self, embed_dim, win_size=8, num_heads=8):
super().__init__()
self.embed_dim = embed_dim
self.win_size = win_size
self.two_d_pos_embed = nn.Parameter(
torch.zeros(1, self.win_size*self.win_size + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
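        # lift each seen 3D coordinate (x, y, z) to an embed_dim-dimensional embedding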
self.pos_embed = nn.Linear(3, embed_dim)
self.blocks = nn.ModuleList([
# each block is a residual block with layernorm -> attention -> layernorm -> mlp
Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
for _ in range(1)
])
self.invalid_coord_token = nn.Parameter(torch.zeros(embed_dim,))
self.initialize_weights()
def initialize_weights(self):
torch.nn.init.normal_(self.cls_token, std=.02)
two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], self.win_size, cls_token=True)
self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
torch.nn.init.normal_(self.invalid_coord_token, std=.02)
def forward(self, coord_obj, mask_obj):
# [B, H, W, C]
emb = self.pos_embed(coord_obj)
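        # zero out embeddings at invalid (unseen) coordinates, then fill them with a learned token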
emb[~mask_obj] = 0.0
emb[~mask_obj] += self.invalid_coord_token
B, H, W, C = emb.shape
        # [B, H/ws, ws, W/ws, ws, C]
emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C]
emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C], add positional encoding that is local to each window
emb = emb + self.two_d_pos_embed[:, 1:, :]
# [1, 1, C]
cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
# [B * H/ws * W/ws, 1, C]
cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
        # [B * H/ws * W/ws, ws*ws + 1, C]
emb = torch.cat((cls_tokens, emb), dim=1)
        # transformer (a single block) that handles each window separately;
        # attention is computed only within each window, since windows are flattened into the batch dimension
for _, blk in enumerate(self.blocks):
emb = blk(emb)
# return the cls token of each window, [B, H/ws*W/ws, C]
return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
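
# Illustrative usage sketch (not part of the original module): with the defaults embed_dim=768 and
# win_size=8, and assuming H and W are divisible by win_size, CoordEmb pools each 8x8 window of the
# seen coordinate map into one token, e.g.
#     coord_emb = CoordEmb(embed_dim=768, win_size=8)
#     coords = torch.randn(2, 112, 112, 3)                    # [B, H, W, 3] seen XYZ map
#     mask = torch.ones(2, 112, 112, dtype=torch.bool)        # [B, H, W] validity mask
#     tokens = coord_emb(coords, mask)                         # [2, (112//8) * (112//8), 768] = [2, 196, 768]
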
class CoordEncAtt(nn.Module):
"""
    Seen-surface encoder based on a transformer: CoordEmb window tokens followed by full self-attention blocks.
"""
def __init__(self,
embed_dim=768, n_blocks=12, num_heads=12, win_size=8,
mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.coord_embed = CoordEmb(embed_dim, win_size, num_heads)
self.blocks = nn.ModuleList([
Block(
embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
drop_path=drop_path
) for _ in range(n_blocks)])
self.norm = norm_layer(embed_dim)
self.initialize_weights()
def initialize_weights(self):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, coord_obj, mask_obj):
# [B, H/ws*W/ws, C]
coord_embedding = self.coord_embed(coord_obj, mask_obj)
# append cls token
# [1, 1, C]
cls_token = self.cls_token
# [B, 1, C]
cls_tokens = cls_token.expand(coord_embedding.shape[0], -1, -1)
# [B, H/ws*W/ws+1, C]
coord_embedding = torch.cat((cls_tokens, coord_embedding), dim=1)
# apply Transformer blocks
for blk in self.blocks:
coord_embedding = blk(coord_embedding)
coord_embedding = self.norm(coord_embedding)
# [B, H/ws*W/ws+1, C]
return coord_embedding
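
# Illustrative usage sketch (not part of the original module): CoordEncAtt pools the coordinate map
# into window tokens with CoordEmb, prepends a CLS token, and attends across all windows, e.g.
#     enc = CoordEncAtt(embed_dim=768, n_blocks=12, num_heads=12, win_size=8)
#     coords = torch.randn(2, 112, 112, 3)
#     mask = torch.ones(2, 112, 112, dtype=torch.bool)
#     out = enc(coords, mask)                                  # [2, 1 + 196, 768]
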
class CoordEncRes(nn.Module):
"""
    Seen-surface encoder based on a ResNet-50 backbone.
"""
def __init__(self, opt):
super().__init__()
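        # ImageNet-pretrained ResNet-50 backbone ("pretrained=True" is the legacy torchvision argument)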
self.encoder = torchvision.models.resnet50(pretrained=True)
self.encoder.fc = nn.Sequential(
Bottleneck_Conv(2048),
Bottleneck_Conv(2048),
nn.Linear(2048, opt.arch.latent_dim)
)
# define hooks
self.seen_feature = None
def feature_hook(model, input, output):
self.seen_feature = output
        # attach hooks: layer3 (output stride 16) or layer4 (output stride 32) is hooked so that the
        # captured feature map has one local feature per win_size x win_size input window
        assert opt.arch.depth.dsp == 1
        if opt.arch.win_size == 16:
self.encoder.layer3.register_forward_hook(feature_hook)
self.depth_feat_proj = nn.Sequential(
Bottleneck_Conv(1024),
Bottleneck_Conv(1024),
nn.Conv2d(1024, opt.arch.latent_dim, 1)
)
        elif opt.arch.win_size == 32:
self.encoder.layer4.register_forward_hook(feature_hook)
self.depth_feat_proj = nn.Sequential(
Bottleneck_Conv(2048),
Bottleneck_Conv(2048),
nn.Conv2d(2048, opt.arch.latent_dim, 1)
)
        else:
            raise NotImplementedError('win_size must be 16 or 32 when using the resnet backbone')
def forward(self, coord_obj, mask_obj):
batch_size = coord_obj.shape[0]
assert len(coord_obj.shape) == len(mask_obj.shape) == 4
mask_obj = mask_obj.float()
coord_obj = coord_obj * mask_obj
# [B, 1, C]
global_feat = self.encoder(coord_obj).unsqueeze(1)
# [B, C, H/ws*W/ws]
local_feat = self.depth_feat_proj(self.seen_feature).view(batch_size, global_feat.shape[-1], -1)
# [B, H/ws*W/ws, C]
local_feat = local_feat.permute(0, 2, 1).contiguous()
# [B, 1+H/ws*W/ws, C]
seen_embedding = torch.cat([global_feat, local_feat], dim=1)
        return seen_embedding
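

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original project. The `opt` fields used below
    # (arch.latent_dim, arch.win_size, arch.depth.dsp) are assumptions inferred from how
    # CoordEncRes reads them; the real code builds `opt` from its own config system.
    from types import SimpleNamespace

    opt = SimpleNamespace(
        arch=SimpleNamespace(latent_dim=768, win_size=32, depth=SimpleNamespace(dsp=1)))
    enc_res = CoordEncRes(opt)
    coords = torch.rand(2, 3, 224, 224)       # [B, 3, H, W] seen XYZ map, channel-first for the resnet
    mask = torch.ones(2, 1, 224, 224)         # [B, 1, H, W] validity mask
    with torch.no_grad():
        seen = enc_res(coords, mask)
    # layer4 hook + 1x1 projection -> [B, 1 + (224/32) * (224/32), latent_dim] = [2, 50, 768]
    print(seen.shape)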