import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_
from .tag_layers import *  # expected to provide AnyAttention, SimpleReasoning, Mlp, Norm, FullRelPos

class PatchEmbed(nn.Module):
    def __init__(self, stride, has_mask=False, in_ch=0, out_ch=0):
        super(PatchEmbed, self).__init__()
        self.to_token = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, stride=stride, groups=in_ch)
        self.proj = nn.Linear(in_ch, out_ch, bias=False)
        self.has_mask = has_mask

    def process_mask(self, x, mask, H, W):
        if mask is None and self.has_mask:
            mask = x.new_zeros((1, 1, H, W))
        if mask is not None:
            H_mask, W_mask = mask.shape[-2:]
            if H_mask != H or W_mask != W:
                mask = F.interpolate(mask, (H, W), mode='nearest')
        return mask

    def forward(self, x, mask):
        """
        Args:
            x: [B, C, H, W]
            mask: [B, 1, H, W] if exists, else None
        Returns:
            out: [B, out_H * out_W, out_C]
            H, W: output height & width
            mask: [B, 1, out_H, out_W] if exists, else None
        """
        out = self.to_token(x)
        B, C, H, W = out.shape
        mask = self.process_mask(out, mask, H, W)
        out = rearrange(out, "b c h w -> b (h w) c").contiguous()
        out = self.proj(out)
        return out, H, W, mask
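
# Usage sketch for PatchEmbed (illustrative only; the shapes below assume a
# 2x-downsampling stage with 64 input and 96 output channels):
#   embed = PatchEmbed(stride=2, in_ch=64, out_ch=96)
#   tokens, H, W, mask = embed(torch.randn(2, 64, 56, 56), mask=None)
#   # tokens: [2, 28 * 28, 96]; H == W == 28; mask stays None since has_mask=False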

class Encoder(nn.Module):
    def __init__(self, dim, num_parts=64, num_enc_heads=1, drop_path=0.1, act=nn.GELU, has_ffn=True):
        super(Encoder, self).__init__()
        self.num_heads = num_enc_heads
        self.enc_attn = AnyAttention(dim, num_enc_heads)
        self.drop_path = DropPath(drop_prob=drop_path) if drop_path else nn.Identity()
        self.reason = SimpleReasoning(num_parts, dim)
        self.enc_ffn = Mlp(dim, hidden_features=dim, act_layer=act) if has_ffn else None

    def forward(self, feats, parts=None, qpos=None, kpos=None, mask=None):
        """
        Args:
            feats: [B, patch_num * patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, patch_num * patch_size, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        Returns:
            parts: [B, N, C]
        """
        attn_out = self.enc_attn(q=parts, k=feats, v=feats, qpos=qpos, kpos=kpos, mask=mask)
        parts = parts + self.drop_path(attn_out)
        parts = self.reason(parts)
        if self.enc_ffn is not None:
            parts = parts + self.drop_path(self.enc_ffn(parts))
        return parts
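
# Sketch of the encoder's role (hypothetical shapes; assumes AnyAttention from
# .tag_layers accepts qpos/kpos/mask=None): the part tokens cross-attend to the
# flattened feature tokens, then pass through SimpleReasoning and the optional FFN.
#   enc = Encoder(dim=96, num_parts=32, num_enc_heads=2)
#   parts = enc(feats, parts=parts)   # feats: [B, L, 96], parts: [B, 32, 96] -> [B, 32, 96]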

class Decoder(nn.Module):
    def __init__(self, dim, num_heads=8, patch_size=7, ffn_exp=3, act=nn.GELU, drop_path=0.1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.attn1 = AnyAttention(dim, num_heads)
        self.attn2 = AnyAttention(dim, num_heads)
        self.rel_pos = FullRelPos(patch_size, patch_size, dim // num_heads)
        self.ffn1 = Mlp(dim, hidden_features=dim * ffn_exp, act_layer=act, norm_layer=Norm)
        self.ffn2 = Mlp(dim, hidden_features=dim * ffn_exp, act_layer=act, norm_layer=Norm)
        self.drop_path = DropPath(drop_path)

    def forward(self, x, parts=None, qpos=None, kpos=None, mask=None, P=0):
        """
        Args:
            x: [B, patch_num * patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
            P: patch_num
        Returns:
            feat: [B, patch_num, patch_size, C]
        """
        dec_mask = None if mask is None else rearrange(mask.squeeze(1), "b h w -> b (h w) 1 1")
        out = self.attn1(q=x, k=parts, v=parts, qpos=qpos, kpos=kpos, mask=dec_mask)
        out = x + self.drop_path(out)
        out = out + self.drop_path(self.ffn1(out))
        # out = rearrange(out, "b (p k) c -> (b p) k c", p=P)
        # local_out = self.attn2(q=out, k=out, v=out, mask=mask, rel_pos=self.rel_pos)
        # out = out + self.drop_path(local_out)
        # out = out + self.drop_path(self.ffn2(out))
        return rearrange(out, "b (p k) c -> b p k c", p=P)
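
# Sketch of a decoder call (illustrative shapes; assumes AnyAttention tolerates
# qpos/kpos/mask=None): the flattened feature tokens attend back to the part
# tokens and are regrouped into patches afterwards.
#   dec = Decoder(dim=96, num_heads=2, patch_size=7)
#   feat = dec(x, parts=parts, P=16)  # x: [B, 16 * 49, 96] -> feat: [B, 16, 49, 96]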

class TAGBlock(nn.Module):
    def __init__(self, dim, ffn_exp=4, drop_path=0.1, patch_size=7, num_heads=1, num_enc_heads=1, num_parts=0):
        super(TAGBlock, self).__init__()
        # self.encoder = Encoder(dim, num_parts=num_parts, num_enc_heads=num_enc_heads, drop_path=drop_path)
        self.decoder = Decoder(dim, num_heads=num_heads, patch_size=patch_size, ffn_exp=ffn_exp, drop_path=drop_path)

    def forward(self, x, parts=None, qpos=None, kpos=None, mask=None):
        """
        Args:
            x: [B, patch_num, patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        Returns:
            feats: [B, patch_num, patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        """
        P = x.shape[1]
        x = rearrange(x, "b p k c -> b (p k) c")
        feats = self.decoder(x, parts=parts, qpos=qpos, kpos=kpos, mask=mask, P=P)
        return feats, parts, qpos, mask

class Stage(nn.Module):
    def __init__(self, in_ch, out_ch, num_blocks, patch_size=7, num_heads=1, num_enc_heads=1, stride=1, num_parts=0,
                 last_np=0, last_enc=False, drop_path=0.1, has_mask=None, ffn_exp=3):
        super(Stage, self).__init__()
        if isinstance(drop_path, float):
            drop_path = [drop_path for _ in range(num_blocks)]
        self.patch_size = patch_size
        self.rpn_qpos = nn.Parameter(torch.Tensor(1, num_parts, 1, out_ch // num_heads))
        self.rpn_kpos = nn.Parameter(torch.Tensor(1, num_parts, 1, out_ch // num_heads))
        self.proj_p = PatchEmbed(stride, has_mask=has_mask, in_ch=in_ch, out_ch=out_ch)
        self.proj_x = PatchEmbed(stride, has_mask=has_mask, in_ch=in_ch, out_ch=out_ch)
        # self.proj_token = nn.Sequential(
        #     nn.Conv1d(last_np, num_parts, 1, bias=False) if last_np != num_parts else nn.Identity(),
        #     nn.Linear(in_ch, out_ch),
        #     Norm(out_ch)
        # )
        self.proj_token = None
        self.proj_norm = Norm(out_ch)

        blocks = [
            TAGBlock(out_ch,
                     patch_size=patch_size,
                     num_heads=num_heads,
                     num_enc_heads=num_enc_heads,
                     num_parts=num_parts,
                     ffn_exp=ffn_exp,
                     drop_path=drop_path[i])
            for i in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.last_enc = Encoder(dim=out_ch,
                                num_enc_heads=num_enc_heads,
                                num_parts=num_parts,
                                drop_path=drop_path[-1],
                                has_ffn=False) if last_enc else None
        self._init_weights()

    def _init_weights(self):
        init.kaiming_uniform_(self.rpn_qpos, a=math.sqrt(5))
        trunc_normal_(self.rpn_qpos, std=.02)
        init.kaiming_uniform_(self.rpn_kpos, a=math.sqrt(5))
        trunc_normal_(self.rpn_kpos, std=.02)

    def to_patch(self, x, patch_size, H, W, mask=None):
        x = rearrange(x, "b (h w) c -> b h w c", h=H)
        pad_l = pad_t = 0
        pad_r = int(math.ceil(W / patch_size)) * patch_size - W
        pad_b = int(math.ceil(H / patch_size)) * patch_size - H
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        if mask is not None:
            mask = F.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=1)
        x = rearrange(x, "b (sh kh) (sw kw) c -> b (sh sw) (kh kw) c", kh=patch_size, kw=patch_size)
        if mask is not None:
            mask = rearrange(mask, "b c (sh kh) (sw kw) -> b c (kh kw) (sh sw)", kh=patch_size, kw=patch_size)
        return x, mask, H + pad_b, W + pad_r
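
    # Worked example for to_patch (illustrative): with H = W = 30 and
    # patch_size = 7, pad_r = pad_b = ceil(30 / 7) * 7 - 30 = 5, so the map is
    # padded to 35 x 35 and split into 5 * 5 = 25 patches of 7 * 7 = 49 tokens;
    # padded positions are marked with value 1 in the mask when a mask is given.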

    def to_part(self, x, mask=None):
        x, H, W, mask = self.proj_p(x, mask=mask)
        x = self.proj_norm(x)
        if self.proj_token is not None:
            # proj_token is disabled above (set to None), so this branch is currently inert
            x = self.proj_token(x)
        ori_H, ori_W = H, W
        x, mask, H, W = self.to_patch(x, self.patch_size, H, W, mask)
        P = x.shape[1]
        x = rearrange(x, "b p k c -> b (p k) c")
        return x

    def forward(self, x, p, mask=None):
        """
        Args:
            x: [B, C, H, W]
            p: [B, N, C] part tokens from the previous stage
            mask: [B, 1, H, W] if exists, else None
        Returns:
            x: [B, out_C, out_H, out_W], or [B, out_N, out_C] when last_enc is used
            parts: [B, out_N, out_C]
            mask: [B, 1, out_H, out_W] if exists, else None
        """
        parts = self.to_part(p, mask=mask)
        x, H, W, mask = self.proj_x(x, mask=mask)
        x = self.proj_norm(x)
        if self.proj_token is not None:
            parts = self.proj_token(parts)

        rpn_qpos, rpn_kpos = self.rpn_qpos, self.rpn_kpos
        rpn_qpos = rpn_qpos.expand(x.shape[0], -1, -1, -1)
        rpn_kpos = rpn_kpos.expand(x.shape[0], -1, -1, -1)

        ori_H, ori_W = H, W
        x, mask, H, W = self.to_patch(x, self.patch_size, H, W, mask)
        for blk in self.blocks:
            # x: [B, K, P, C]
            x, parts, rpn_qpos, mask = blk(x,
                                           parts=parts,
                                           qpos=rpn_qpos,
                                           kpos=rpn_kpos,
                                           mask=mask)

        dec_mask = None if mask is None else rearrange(mask.squeeze(1), "b h w -> b 1 1 (h w)")
        # TAG.forward unpacks (features, parts, mask) from every stage
        if self.last_enc is not None:
            x = rearrange(x, "b p k c -> b (p k) c")
            rpn_out = self.last_enc(x, parts=parts, qpos=rpn_qpos, mask=dec_mask)
            return rpn_out, parts, mask
        else:
            x = rearrange(x, "b (sh sw) (kh kw) c -> b c (sh kh) (sw kw)", kh=self.patch_size, sh=H // self.patch_size)
            x = x[:, :, :ori_H, :ori_W]
            return x, parts, mask
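
# Data flow of a Stage, as implemented above: the incoming part tokens and the
# feature map are each embedded by a PatchEmbed, the feature tokens are split
# into patch_size x patch_size windows by to_patch, each TAGBlock lets the
# window tokens attend to the part tokens, and the final stage can optionally
# summarize everything back into part tokens through last_enc.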

class TAG(nn.Module):
    def __init__(self,
                 in_chans=3,
                 inplanes=64,
                 num_layers=(3, 4, 6, 3),
                 num_chs=(256, 512, 1024, 2048),
                 num_strides=(1, 2, 2, 2),
                 num_classes=1000,
                 num_heads=(1, 1, 1, 1),
                 num_parts=(1, 1, 1, 1),
                 patch_sizes=(1, 1, 1, 1),
                 drop_path=0.1,
                 num_enc_heads=(1, 1, 1, 1),
                 act=nn.GELU,
                 ffn_exp=3,
                 no_pos_wd=False,
                 has_last_encoder=False,
                 pretrained=False,
                 **ret_args):
        super(TAG, self).__init__()
        self.depth = len(num_layers)
        self.no_pos_wd = no_pos_wd

        self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, padding=3, stride=2, bias=False)
        self.norm1 = nn.BatchNorm2d(inplanes)
        self.act = act()
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.rpn_tokens = nn.Parameter(torch.Tensor(1, num_parts[0], inplanes))

        drop_path_ratios = torch.linspace(0, drop_path, sum(num_layers))
        last_chs = [inplanes, *num_chs[:-1]]
        last_nps = [num_parts[0], *num_parts[:-1]]
        for i, n_l in enumerate(num_layers):
            stage_ratios = [drop_path_ratios[sum(num_layers[:i]) + did] for did in range(n_l)]
            setattr(self,
                    "layer_{}".format(i),
                    Stage(last_chs[i],
                          num_chs[i],
                          n_l,
                          stride=num_strides[i],
                          num_heads=num_heads[i],
                          num_enc_heads=num_enc_heads[i],
                          patch_size=patch_sizes[i],
                          drop_path=stage_ratios,
                          ffn_exp=ffn_exp,
                          num_parts=num_parts[i],
                          last_np=last_nps[i],
                          last_enc=has_last_encoder and i == len(num_layers) - 1))

        if has_last_encoder:
            self.last_fc = nn.Linear(num_chs[-1], num_classes)
        else:
            self.last_linear = nn.Conv2d(num_chs[-1], num_chs[-1], kernel_size=1, bias=False)
            self.last_norm = nn.BatchNorm2d(num_chs[-1])
            self.pool2 = nn.AdaptiveAvgPool2d(1)
            self.last_fc = nn.Linear(num_chs[-1], num_classes)

        self.has_last_encoder = has_last_encoder
        self._init_weights(pretrained=pretrained)

    def no_weight_decay(self):
        skip_pattern = ['rel_pos'] if self.no_pos_wd else []
        no_wd_layers = set()
        for name, param in self.named_parameters():
            for skip_name in skip_pattern:
                if skip_name in name:
                    no_wd_layers.add(name)
        return no_wd_layers
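
    # Sketch: the parameter names returned by no_weight_decay() are typically
    # routed to a zero-weight-decay optimizer group (illustrative only; `model`
    # is a hypothetical TAG instance):
    #   skip = model.no_weight_decay()
    #   decay = [p for n, p in model.named_parameters() if n not in skip]
    #   no_decay = [p for n, p in model.named_parameters() if n in skip]
    #   optimizer = torch.optim.AdamW(
    #       [{"params": decay, "weight_decay": 5e-2},
    #        {"params": no_decay, "weight_decay": 0.0}], lr=1e-3)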

    def _init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            state_dict = torch.load(pretrained, map_location=torch.device("cpu"))
            if "state_dict" in state_dict.keys():
                state_dict = state_dict["state_dict"]
            self.load_state_dict(state_dict, strict=True)
            return

        init.kaiming_uniform_(self.rpn_tokens, a=math.sqrt(5))
        trunc_normal_(self.rpn_tokens, std=.02)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                if not torch.sum(m.weight.data == 0).item() == m.num_features:  # zero gamma
                    m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.pool1(out)
        B, _, H, W = out.shape

        rpn_tokens, mask = self.rpn_tokens.expand(x.shape[0], -1, -1), None
        for i in range(self.depth):
            layer = getattr(self, "layer_{}".format(i))
            out, rpn_tokens, mask = layer(out, rpn_tokens, mask=mask)

        if self.has_last_encoder:
            out = self.act(out)
            out = out.mean(1)
        else:
            out = self.last_linear(out)
            out = self.last_norm(out)
            out = self.act(out)
            out = self.pool2(out)
            out = out.flatten(1)  # [B, C]; flatten(1) keeps the batch dim even for a single-image batch
        out = self.last_fc(out)
        return out.view(out.size(0), -1)
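
# Shape trace for the stem and stages (illustrative, assuming a 224 x 224 input
# and the default num_strides=(1, 2, 2, 2)): conv1 with stride 2 gives
# 112 x 112, pool1 gives 56 x 56, and the four stages then run at
# 56 x 56, 28 x 28, 14 x 14, and 7 x 7 before the classification head.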

# Register the factory functions with timm so the variants can be built by name.
@register_model
def TAG_mobile(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(48, 96, 192, 384), patch_sizes=[8, 7, 7, 7], num_heads=[1, 2, 4, 8],
                     num_enc_heads=[1, 2, 4, 8], num_parts=[16, 16, 16, 32], num_layers=[1, 1, 1, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0., **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_tiny(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(64, 128, 256, 512), patch_sizes=[8, 7, 7, 7], num_heads=[1, 2, 4, 8],
                     num_enc_heads=[1, 2, 4, 8], num_parts=[32, 32, 32, 32], num_layers=[1, 1, 2, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0.1, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_small(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(96, 192, 384, 768), patch_sizes=[8, 7, 7, 7], num_heads=[3, 6, 12, 24],
                     num_enc_heads=[1, 3, 6, 12], num_parts=[64, 64, 64, 64], num_layers=[1, 1, 3, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0.1, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_medium(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(96, 192, 384, 768), patch_sizes=[8, 7, 7, 7], num_heads=[3, 6, 12, 24],
                     num_enc_heads=[1, 3, 6, 12], num_parts=[64, 64, 64, 128], num_layers=[1, 1, 8, 1], ffn_exp=3,
                     has_last_encoder=False, drop_path=0.2, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_base(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(128, 256, 512, 1024), patch_sizes=[8, 7, 7, 7], num_heads=[4, 8, 16, 32],
                     num_enc_heads=[1, 4, 8, 16], num_parts=[64, 64, 128, 128], num_layers=[1, 1, 8, 1], ffn_exp=3,
                     has_last_encoder=False, drop_path=0.3, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)
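
# Usage sketch for the timm registry pattern (illustrative; assumes this file is
# importable as part of a package so the relative .tag_layers import resolves,
# and that the factories above are registered via @register_model):
#   import timm
#   model = timm.create_model("TAG_tiny", pretrained=False)
#   # the model is intended to be called on [B, 3, 224, 224] image batches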